|
9 | 9 | from mlir.dialects.scf import IfOp, ForOp |
10 | 10 | from mlir.ir import InsertionPoint, Value, OpResultList, OpResult |
11 | 11 |
|
| 12 | +import mlir_utils.types as T |
12 | 13 | from mlir_utils.ast.canonicalize import ( |
13 | 14 | StrictTransformer, |
14 | 15 | Canonicalizer, |
|
18 | 19 | from mlir_utils.ast.util import ast_call |
19 | 20 | from mlir_utils.dialects.ext.arith import constant |
20 | 21 | from mlir_utils.dialects.scf import yield_ as yield__ |
21 | | -from mlir_utils.types import opaque_t |
22 | 22 | from mlir_utils.util import ( |
23 | 23 | region_op, |
24 | 24 | maybe_cast, |
@@ -101,8 +101,6 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None): |
101 | 101 |
|
102 | 102 | if_ = region_op(_if, terminator=yield__) |
103 | 103 |
|
104 | | -_placeholder_opaque_t = opaque_t("scf", "placeholder") |
105 | | - |
106 | 104 |
|
107 | 105 | class IfStack: |
108 | 106 | __current_if_op: list[IfOp] = [] |
@@ -175,7 +173,7 @@ def yield_(*args): |
175 | 173 |
|
176 | 174 | assert len(results) == len(unpacked_args), f"{results=}, {unpacked_args=}" |
177 | 175 | for i, r in enumerate(results): |
178 | | - if r.type == _placeholder_opaque_t: |
| 176 | + if r.type == T._placeholder_opaque_t(): |
179 | 177 | r.set_type(unpacked_args[i].type) |
180 | 178 |
|
181 | 179 | yield_(*args) |
@@ -325,13 +323,13 @@ def insert_with_results( |
325 | 323 | ), f"conditional with := must explicitly yield on last line" |
326 | 324 | yield_expr = last_statement.body[0] |
327 | 325 | if m.matches(yield_expr.value, m.Call(func=m.Name(stack_yield.__name__))): |
328 | | - results = [cst.Element(cst.Name("_placeholder_opaque_t"))] * len( |
| 326 | + results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len( |
329 | 327 | yield_expr.value.args |
330 | 328 | ) |
331 | 329 | elif m.matches(yield_expr.value.value, m.Name()): |
332 | | - results = [cst.Element(cst.Name("_placeholder_opaque_t"))] |
| 330 | + results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] |
333 | 331 | elif m.matches(yield_expr.value.value, m.Tuple()): |
334 | | - results = [cst.Element(cst.Name("_placeholder_opaque_t"))] * len( |
| 332 | + results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len( |
335 | 333 | yield_expr.value.value.elements |
336 | 334 | ) |
337 | 335 | results = cst.Tuple(results) |
@@ -422,14 +420,13 @@ def patch_bytecode(self, code: ConcreteBytecode, f): |
422 | 420 | str(OpCode.NOP), lineno=c.lineno, location=c.location |
423 | 421 | ) |
424 | 422 |
|
425 | | - # TODO(max): this is bad |
426 | 423 | f.__globals__[else_.__name__] = else_ |
427 | 424 | f.__globals__[end_branch.__name__] = end_branch |
428 | 425 | f.__globals__[end_if.__name__] = end_if |
429 | 426 | f.__globals__[stack_if.__name__] = stack_if |
430 | 427 | f.__globals__[stack_yield.__name__] = stack_yield |
431 | 428 | f.__globals__[yield_.__name__] = yield_ |
432 | | - f.__globals__["_placeholder_opaque_t"] = _placeholder_opaque_t |
| 429 | + f.__globals__[T._placeholder_opaque_t.__name__] = T._placeholder_opaque_t |
433 | 430 | return code |
434 | 431 |
|
435 | 432 |
|
|
0 commit comments