|
45 | 45 | Value, |
46 | 46 | ) |
47 | 47 |
|
48 | | -from ..rewriter import ( |
49 | | - Pass, |
50 | | -) |
| 48 | + |
class Pass:
    """Minimal Pass base class for custom op expansion.

    Subclasses implement `run()` and operate on `root_op`, which is
    expected to be a module-like Operation containing `func.func` ops.
    """

    def __init__(self, root_op: Operation):
        self.root_op = root_op

    def run(self):
        """Execute the pass. Subclasses must override."""
        raise NotImplementedError

    @property
    def funcs(self):
        """Get all func.func operations in the module.

        Returns a list of match records exposing each matched op via a
        `.op` attribute.
        """
        results = []
        # Traverse all regions and blocks to find func.func operations
        for region in self.root_op.regions:
            for block in region.blocks:
                for op in block.operations:
                    # `op` may be an OpView wrapper; `.operation` is the
                    # underlying generic Operation carrying `.name`.
                    if op.operation.name == "func.func":
                        # Use a single shared record class instead of
                        # synthesizing a new class per match with type().
                        results.append(_OpMatchResult(op))
        return results

    def erase_unused_op(self, op: Operation):
        """Recursively erases any unused torch ops, starting with op."""
        from ...support.ir_imports import OpResult

        worklist = {op}
        while worklist:
            pending, worklist = worklist, set()
            for candidate in pending:
                if not self._is_erasable_value_op(candidate):
                    continue
                if not self._op_is_live(candidate):
                    # Queue the producers of this op's operands: erasing
                    # the candidate may leave them unused as well.
                    for operand in candidate.operands:
                        if OpResult.isinstance(operand):
                            worklist.add(operand.owner)
                    candidate.erase()

    def _is_erasable_value_op(self, op: Operation):
        # Only torch-dialect ops are considered safe to erase here.
        name = op.name
        return name.startswith("torch.") or name.startswith("torch_c.")

    def _op_is_live(self, op: Operation) -> bool:
        """Return True if any result of `op` still has at least one use."""
        for result in op.results:
            # Iterate rather than calling next() directly: this works
            # whether `uses` is an iterator or a re-iterable view, and
            # a single element is enough to prove liveness.
            for _ in result.uses:
                return True
        return False


class _OpMatchResult:
    """Lightweight record exposing a matched op via `.op`.

    Replaces the per-iteration ad-hoc `type("OpMatchResult", ...)` class
    construction with one shared, slotted class.
    """

    __slots__ = ("op",)

    def __init__(self, op):
        self.op = op
51 | 101 |
|
52 | 102 |
|
53 | 103 | class ExpandCustomOpsPass(Pass): |
@@ -173,7 +223,7 @@ def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg: |
173 | 223 | element_type_asm = str(rtt.element_type) |
174 | 224 | try: |
175 | 225 | dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm] |
176 | | - except KeyError as e: |
| 226 | + except KeyError: |
177 | 227 | raise AssertionError( |
178 | 228 | f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE" |
179 | 229 | ) |
@@ -410,7 +460,7 @@ def yield_results(self, *results: Value): |
410 | 460 | torch_op_results: list[Value] = list(self.torch_op.results) |
411 | 461 | assert len(results) == len( |
412 | 462 | torch_op_results |
413 | | - ), f"Mismatched yield_results with custom op results" |
| 463 | + ), "Mismatched yield_results with custom op results" |
414 | 464 | for new_result, old_result in zip(results, torch_op_results): |
415 | 465 | torch_type = old_result.type |
416 | 466 | new_result = self.type_converter.materialize_native_to_torch( |
|
0 commit comments