Skip to content

Commit 83f111d

Browse files
authored
[Wave] Remove unused code in iree/turbine/transforms (#39)
This PR removes unused code in iree/turbine/transforms. The Pass class is copied from rewriter.py with some modifications to make the tests pass. Signed-off-by: Harsh Menon <[email protected]>
1 parent 4f27acf commit 83f111d

File tree

7 files changed

+56
-849
lines changed

7 files changed

+56
-849
lines changed

iree/turbine/transforms/builder.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

iree/turbine/transforms/general/add_metadata.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

iree/turbine/transforms/general/custom_op_expansion.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,59 @@
4545
Value,
4646
)
4747

48-
from ..rewriter import (
49-
Pass,
50-
)
48+
49+
class Pass:
50+
"""Minimal Pass base class for custom op expansion."""
51+
52+
def __init__(self, root_op: Operation):
53+
self.root_op = root_op
54+
55+
def run(self):
56+
raise NotImplementedError
57+
58+
@property
59+
def funcs(self):
60+
"""Get all func.func operations in the module."""
61+
results = []
62+
# Traverse all regions and blocks to find func.func operations
63+
for region in self.root_op.regions:
64+
for block in region.blocks:
65+
for op in block.operations:
66+
actual_op = op.operation
67+
if actual_op.name == "func.func":
68+
results.append(type("OpMatchResult", (), {"op": op})())
69+
return results
70+
71+
def erase_unused_op(self, op: Operation):
72+
"""Recursively erases any unused torch ops, starting with op."""
73+
from ...support.ir_imports import OpResult
74+
75+
worklist = set()
76+
worklist.add(op)
77+
while worklist:
78+
ops = worklist
79+
worklist = set()
80+
for op in ops:
81+
if not self._is_erasable_value_op(op):
82+
continue
83+
if not self._op_is_live(op):
84+
for operand in op.operands:
85+
if OpResult.isinstance(operand):
86+
worklist.add(operand.owner)
87+
op.erase()
88+
89+
def _is_erasable_value_op(self, op: Operation):
90+
name = op.name
91+
return name.startswith("torch.") or name.startswith("torch_c.")
92+
93+
def _op_is_live(self, op: Operation) -> bool:
94+
for r in op.results:
95+
try:
96+
next(r.uses)
97+
return True
98+
except StopIteration:
99+
pass
100+
return False
51101

52102

53103
class ExpandCustomOpsPass(Pass):
@@ -173,7 +223,7 @@ def arg_tensor(self, arg: int, *, inplace_tied: bool = False) -> TensorArg:
173223
element_type_asm = str(rtt.element_type)
174224
try:
175225
dtype = MLIR_TYPE_ASM_TO_TORCH_DTYPE[element_type_asm]
176-
except KeyError as e:
226+
except KeyError:
177227
raise AssertionError(
178228
f"Could not find dtype mapping for {element_type_asm} in MLIR_TYPE_ASM_TO_TORCH_DTYPE"
179229
)
@@ -410,7 +460,7 @@ def yield_results(self, *results: Value):
410460
torch_op_results: list[Value] = list(self.torch_op.results)
411461
assert len(results) == len(
412462
torch_op_results
413-
), f"Mismatched yield_results with custom op results"
463+
), "Mismatched yield_results with custom op results"
414464
for new_result, old_result in zip(results, torch_op_results):
415465
torch_type = old_result.type
416466
new_result = self.type_converter.materialize_native_to_torch(

iree/turbine/transforms/general/rename_parameters.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

iree/turbine/transforms/merger.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
from typing import Any, Dict, List, Optional, Sequence, Union
7+
from typing import Dict, List, Optional, Sequence
88

99
from iree.compiler.ir import (
1010
Attribute,
11-
Block,
1211
InsertionPoint,
1312
Operation,
1413
StringAttr,
1514
SymbolTable,
16-
Context,
1715
)
1816

1917

0 commit comments

Comments
 (0)