Skip to content

Commit a472d1a

Browse files
authored
[Wave] Migrate iree.turbine.transforms -> wave.transforms (#40)
Signed-off-by: Harsh Menon <[email protected]>
1 parent 83f111d commit a472d1a

File tree

5 files changed

+13
-17
lines changed

5 files changed

+13
-17
lines changed

iree/turbine/aot/compiled_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
StringAttr,
3030
)
3131
from ..support.logging import aot_logger as logger
32-
from ..transforms.general.custom_op_expansion import ExpandCustomOpsPass
32+
from wave.transforms.general.custom_op_expansion import ExpandCustomOpsPass
3333

3434
from .support.procedural import (
3535
GlobalsDef,

iree/turbine/runtime/op_reg/impl_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def generate(kb: KernelBuilder):
3636
Value,
3737
)
3838

39-
from ...transforms.merger import Merger
39+
from wave.transforms.merger import Merger
4040

4141
from .base import (
4242
KernelBuilder,

lit_tests/kernel/wave/sharktank_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import textwrap
44
from typing import Optional
5+
from wave.transforms.merger import Merger
56

67
import torch
78
from jinja2 import BaseLoader, Environment
@@ -43,7 +44,6 @@
4344
from iree.turbine.runtime.op_reg.impl_helper import (
4445
call_function,
4546
)
46-
from iree.turbine.transforms.merger import Merger
4747

4848
_JINJA2_ENVIRONMENT: Optional[Environment] = None
4949

iree/turbine/transforms/general/custom_op_expansion.py renamed to wave/transforms/general/custom_op_expansion.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,37 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
from typing import Callable, Optional
8+
from wave.dynamo.type_conversion import (
9+
NativeTypeConverter,
10+
)
811

912
import torch
1013
from torch import Tensor
1114
from torch._subclasses.fake_tensor import FakeTensorMode
1215
from torch.fx.experimental.symbolic_shapes import ShapeEnv
1316

14-
from wave.dynamo.type_conversion import (
15-
NativeTypeConverter,
16-
)
17-
18-
from ...runtime.op_reg.base import (
17+
from iree.turbine.runtime.op_reg.base import (
1918
ALL_CUSTOM_OP_REGS,
2019
AttrArg,
20+
CustomOp,
2121
EmptyOptionalTensorArg,
2222
IntArg,
23-
CustomOp,
2423
KernelBuilder,
2524
KernelSelection,
2625
TensorArg,
2726
TensorListArg,
2827
)
29-
30-
from ...support.conversions import (
28+
from iree.turbine.support.conversions import (
3129
MLIR_TYPE_ASM_TO_TORCH_DTYPE,
3230
)
33-
34-
from ...support.ir_imports import (
31+
from iree.turbine.support.ir_imports import (
3532
Block,
3633
FloatAttr,
37-
IrType,
3834
InsertionPoint,
3935
IntegerAttr,
40-
OpResult,
36+
IrType,
4137
Operation,
38+
OpResult,
4239
RankedTensorType,
4340
StringAttr,
4441
SymbolTable,
@@ -70,7 +67,7 @@ def funcs(self):
7067

7168
def erase_unused_op(self, op: Operation):
7269
"""Recursively erases any unused torch ops, starting with op."""
73-
from ...support.ir_imports import OpResult
70+
from iree.turbine.support.ir_imports import OpResult
7471

7572
worklist = set()
7673
worklist.add(op)

iree/turbine/transforms/merger.py renamed to wave/transforms/merger.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
SymbolTable,
1515
)
1616

17-
1817
__all__ = [
1918
"Merger",
2019
]

0 commit comments

Comments
 (0)