Skip to content

Commit 098dbb8

Browse files
authored
[Wave] Migrate iree.turbine.aot -> wave.aot (#45)
Also run ruff on the code to make sure it passes pre-commit. Signed-off-by: Harsh Menon <[email protected]>
1 parent 6364119 commit 098dbb8

File tree

24 files changed

+439
-401
lines changed

24 files changed

+439
-401
lines changed

docs/core/aot.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
`iree.turbine.aot`
1+
`wave.aot`
22
=====================
33

44
aot
55
--------------
66

7-
.. automodule:: iree.turbine.aot
7+
.. automodule:: wave.aot
88
:imported-members:
99
:members:
1010
:undoc-members:
1111

1212
passes
1313
--------------
1414

15-
.. automodule:: iree.turbine.aot.passes
15+
.. automodule:: wave.aot.passes
1616
:imported-members:
1717
:members:
1818
:undoc-members:
1919

2020
support
2121
--------------
2222

23-
.. automodule:: iree.turbine.aot.support.procedural
23+
.. automodule:: wave.aot.support.procedural
2424
:imported-members:
2525
:members:
2626
:undoc-members:
2727

28-
.. automodule:: iree.turbine.aot.support.procedural.exported_program
28+
.. automodule:: wave.aot.support.procedural.exported_program
2929
:members:
3030
:undoc-members:
3131

3232
build actions
3333
--------------
3434

35-
.. py:module:: iree.turbine.aot.build_actions
35+
.. py:module:: wave.aot.build_actions
3636
3737
.. autofunction:: turbine_generate
3838
.. autoclass:: RemoteGenerator

iree/turbine/kernel/wave/codegen/emitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.fx as fx
1515

1616
from iree.turbine.kernel.lang.global_symbols import *
17-
from iree.turbine.aot.support.ir_utils import (
17+
from wave.aot.support.ir_utils import (
1818
_is_float_type,
1919
_is_integer_like_type,
2020
)

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
scf_d,
3838
vector_d,
3939
)
40-
from iree.turbine.aot.support.ir_utils import (
40+
from wave.aot.support.ir_utils import (
4141
_is_float_type,
4242
_is_integer_like_type,
4343
_is_signed_or_signless_type,

lit_tests/kernel/wave/sharktank_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# RUN: python %s | FileCheck %s
22

33
import textwrap
4+
import wave.aot as aot
45
from typing import Optional
56
from wave.transforms.merger import Merger
67

78
import torch
89
from jinja2 import BaseLoader, Environment
910

10-
import iree.turbine.aot as aot
1111
import iree.turbine.kernel.wave as tkw
1212
from iree.compiler.ir import (
1313
Context,

iree/turbine/aot/__init__.py renamed to wave/aot/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
"""
2-
Toolkit for ahead-of-time (AOT) compilation and export of PyTorch programs.
3-
"""
1+
"""Toolkit for ahead-of-time (AOT) compilation and export of PyTorch programs."""
42

53
# Copyright 2023 Nod Labs, Inc
64
#
@@ -13,5 +11,10 @@
1311
from .decompositions import *
1412
from .exporter import *
1513
from .fx_programs import FxPrograms, FxProgramsBuilder
16-
from .tensor_traits import *
1714
from .params import *
15+
from .tensor_traits import *
16+
17+
__all__ = [
18+
"FxPrograms",
19+
"FxProgramsBuilder",
20+
]

iree/turbine/aot/builtins/__init__.py renamed to wave/aot/builtins/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
from .globals import *
8-
from .jittable import jittable
97
from ..support.procedural import (
108
AbstractBool,
119
AbstractF32,
@@ -19,23 +17,25 @@
1917

2018
# Export the instantiated IREEEmitter as "IREE"
2119
from ..support.procedural.iree_emitter import IREEEmitter as _IREEEmitter
20+
from .globals import *
21+
from .jittable import jittable
2222

2323
IREE = _IREEEmitter()
2424
del _IREEEmitter
2525

2626
__all__ = [
27+
"IREE",
2728
"AbstractBool",
2829
"AbstractF32",
2930
"AbstractF64",
3031
"AbstractI32",
3132
"AbstractI64",
3233
"AbstractIndex",
3334
"AbstractTensor",
34-
"IREE",
3535
"abstractify",
36+
"export_buffers",
3637
"export_global",
3738
"export_global_tree",
3839
"export_parameters",
39-
"export_buffers",
4040
"jittable",
4141
]

iree/turbine/aot/builtins/globals.py renamed to wave/aot/builtins/globals.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,39 @@
55
# See https://llvm.org/LICENSE.txt for license information.
66
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
77

8-
from typing import Any, Callable, Optional
8+
from typing import Any, Optional
99

10-
import torch.nn as nn
10+
from torch import nn
11+
from torch.utils._pytree import (
12+
TreeSpec,
13+
tree_flatten,
14+
tree_map,
15+
)
1116

17+
from ..support.ir_utils import (
18+
GlobalAttributes,
19+
NameMapCallback,
20+
)
1221
from ..support.procedural import (
13-
AbstractTypedef,
1422
Abstractifiable,
23+
AbstractTypedef,
1524
GlobalsDef,
1625
TreeAbstractifiable,
1726
abstractify_single_value,
1827
)
1928

20-
from ..support.ir_utils import (
21-
NameMapCallback,
22-
GlobalAttributes,
23-
)
24-
25-
from torch.utils._pytree import (
26-
TreeSpec,
27-
tree_flatten,
28-
tree_map,
29-
)
30-
31-
3229
__all__ = [
30+
"export_buffers",
3331
"export_global",
3432
"export_global_tree",
3533
"export_parameters",
36-
"export_buffers",
3734
]
3835

3936

4037
class export_global(GlobalsDef, Abstractifiable):
4138
"""Exports a single global into a CompiledModule."""
4239

43-
__slots__ = ["_name", "_value", "_schema"]
40+
__slots__ = ["_name", "_schema", "_value"]
4441

4542
def __init__(
4643
self,
@@ -104,7 +101,7 @@ def __init__(
104101
self._items, self._schema = tree_flatten(tree)
105102
self._names, _ = tree_flatten(_transform_tree_to_names("", tree))
106103
assert len(self._items) == len(
107-
self._names
104+
self._names,
108105
), f"Name and value tree are different sizes: {len(self._items)} != {len(self._names)}"
109106

110107
def items(self):
@@ -235,10 +232,9 @@ def _transform_tree_to_names(prefix: str, tree):
235232
return tree.__class__(
236233
(k, _transform_tree_to_names(join(k), v)) for k, v in tree.items()
237234
)
238-
elif isinstance(tree, (list, tuple)):
235+
if isinstance(tree, (list, tuple)):
239236
return tree.__class__(
240237
_transform_tree_to_names(join(str(index)), v)
241238
for index, v in enumerate(tree)
242239
)
243-
else:
244-
return prefix
240+
return prefix

iree/turbine/aot/builtins/jittable.py renamed to wave/aot/builtins/jittable.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77

88
"""Tracing builtins."""
99

10-
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
11-
12-
import warnings
10+
from collections.abc import Sequence
11+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
1312

1413
import torch
15-
from torch._decomp import get_decompositions
1614
import torch._dynamo as dynamo
15+
from torch._decomp import get_decompositions
1716
from torch.fx import (
1817
GraphModule,
1918
)
@@ -23,13 +22,12 @@
2322
)
2423

2524
from iree.compiler.extras.fx_importer import (
26-
GraphNodeImporter,
2725
FxImporter,
2826
FxImporterHooks,
27+
GraphNodeImporter,
2928
InputInfo,
3029
)
31-
32-
from ...support.ir_imports import (
30+
from iree.turbine.support.ir_imports import (
3331
FlatSymbolRefAttr,
3432
FunctionType,
3533
Operation,
@@ -40,18 +38,15 @@
4038
func_d,
4139
util_d,
4240
)
43-
44-
from ...support.logging import aot_logger as logger
41+
from iree.turbine.support.logging import aot_logger as logger
4542

4643
from ..decompositions import current_aot_decompositions
4744
from ..passes import (
4845
functorch_functionalize,
4946
)
50-
5147
from ..support.ir_utils import (
5248
ModuleBuilder,
5349
)
54-
5550
from ..support.procedural import (
5651
CallableIntrinsic,
5752
IrImmediateTensor,
@@ -79,7 +74,10 @@ def __init__(self, module_builder: ModuleBuilder):
7974
self.cloned_global_symbols: set[str] = set()
8075

8176
def resolve_literal(
82-
self, gni: GraphNodeImporter, literal: Any, info: Optional[InputInfo] = None
77+
self,
78+
gni: GraphNodeImporter,
79+
literal: Any,
80+
info: Optional[InputInfo] = None,
8381
) -> Optional[Value]:
8482
module_builder = self.module_builder
8583
cloned_global_symbols = self.cloned_global_symbols
@@ -110,7 +108,8 @@ def resolve_literal(
110108
# Emit a global load and conversion.
111109
vtensor_type = gni._cc.tensor_to_vtensor_type(literal)
112110
loaded_value = util_d.GlobalLoadOp(
113-
materialized_global.ir_type, materialized_global.symbol_name
111+
materialized_global.ir_type,
112+
materialized_global.symbol_name,
114113
).result
115114
converted_value = Operation.create(
116115
"torch_c.from_builtin_tensor",
@@ -131,11 +130,11 @@ class jittable(CallableIntrinsic):
131130
"""
132131

133132
__slots__ = [
134-
"dynamic_shapes",
133+
"_passes",
135134
"decomposition_table",
136-
"wrapped_f",
135+
"dynamic_shapes",
137136
"function_name",
138-
"_passes",
137+
"wrapped_f",
139138
]
140139

141140
def __init__(
@@ -266,10 +265,10 @@ def flat_wrapped_f(*args):
266265
# TODO: Debug upstream why iteration over children isn't creating a typed view.
267266
# This should just be `target_op.function_type`
268267
target_ftype = FunctionType(
269-
TypeAttr(target_op.attributes["function_type"]).value
268+
TypeAttr(target_op.attributes["function_type"]).value,
270269
)
271270
target_symbol_ref = FlatSymbolRefAttr.get(
272-
StringAttr(target_op.attributes["sym_name"]).value
271+
StringAttr(target_op.attributes["sym_name"]).value,
273272
)
274273

275274
assert len(flat_ir_args) == len(target_ftype.inputs), (
@@ -287,7 +286,9 @@ def flat_wrapped_f(*args):
287286

288287
with proc_trace.ip, proc_trace.loc:
289288
flat_ir_results = func_d.CallOp(
290-
target_ftype.results, target_symbol_ref, flat_ir_args
289+
target_ftype.results,
290+
target_symbol_ref,
291+
flat_ir_args,
291292
).results
292293

293294
assert len(flat_ir_results) == len(result_tensor_infos)
@@ -300,7 +301,7 @@ def flat_wrapped_f(*args):
300301
flat_py_results.append(IrImmediateTensor(native_ir_result, dtype))
301302
else:
302303
raise TypeError(
303-
f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}"
304+
f"Unknown PyTorch->IREE value mapping for jittable result: {result_tensor_info}->{native_ir_result}",
304305
)
305306

306307
tree_py_results = tree_unflatten(flat_py_results, out_spec)
@@ -317,14 +318,14 @@ def _split_py_arg(self, arg) -> Tuple[Value, Any]:
317318
class _Merger:
318319
__slots__ = [
319320
"context",
320-
"to_module_builder",
321321
"from_module_op",
322322
"from_symbol_table",
323323
"import_function_name",
324-
"rename_map",
325324
"nested_symbol_ops",
326325
"nested_symbol_table_ops",
327326
"private_attr",
327+
"rename_map",
328+
"to_module_builder",
328329
]
329330

330331
def __init__(
@@ -440,7 +441,7 @@ def _extract_graph_output_metadata(
440441
out_spec = gm._out_spec
441442
except AttributeError:
442443
raise AssertionError(
443-
"Expected PyTorch to add an _out_spec attribute to the GraphModule"
444+
"Expected PyTorch to add an _out_spec attribute to the GraphModule",
444445
)
445446

446447
output_nodes = []

0 commit comments

Comments
 (0)