Skip to content

Commit 45f2e57

Browse files
Sync fx_importer from torch-mlir. (#475)
I was hoping that we could just depend on it via IREE, but it needed some local patches (sending upstream) for type imports that don't exist in old versions of PyTorch. So for now, just updating the local fork.
1 parent fabd52c commit 45f2e57

File tree

2 files changed

+800
-94
lines changed

2 files changed

+800
-94
lines changed

core/shark_turbine/aot/builtins/jittable.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,21 @@
2222
)
2323
from torch.fx.passes.shape_prop import TensorMetadata
2424

25-
from ...dynamo.passes import (
26-
DEFAULT_DECOMPOSITIONS,
27-
)
25+
# TODO: Switch to upstream fx_importer vs local fork when ready.
26+
# from iree.compiler.extras.fx_importer import (
27+
# GraphNodeImporter,
28+
# FxImporter,
29+
# FxImporterHooks,
30+
# )
2831

2932
from ...importers.fx_importer import (
3033
GraphNodeImporter,
3134
FxImporter,
35+
FxImporterHooks,
36+
)
37+
38+
from ...dynamo.passes import (
39+
DEFAULT_DECOMPOSITIONS,
3240
)
3341

3442
from ...support.ir_imports import (
@@ -68,23 +76,33 @@
6876
StringAttrOrStr = Union[StringAttr, str]
6977

7078

71-
def _make_literal_resolver(module_builder: ModuleBuilder):
72-
# When we first encounter a global during import, we have to pull it
73-
# into the local module being populated by the GraphNodeImporter. This
74-
# will exactly match the global in the target module we are merging into
75-
# and exists so that the IR is valid during Fx import. We keep the set of
76-
# symbols we have done this to here.
77-
cloned_global_symbols: Set[str] = set()
79+
class _Hooks(FxImporterHooks):
80+
__slots__ = [
81+
"cloned_global_symbols",
82+
"module_builder",
83+
]
84+
85+
def __init__(self, module_builder: ModuleBuilder):
86+
self.module_builder = module_builder
87+
# When we first encounter a global during import, we have to pull it
88+
# into the local module being populated by the GraphNodeImporter. This
89+
# will exactly match the global in the target module we are merging into
90+
# and exists so that the IR is valid during Fx import. We keep the set of
91+
# symbols we have done this to here.
92+
self.cloned_global_symbols: set[str] = set()
93+
94+
def resolve_literal(self, gni: GraphNodeImporter, literal: Any) -> Optional[Value]:
95+
module_builder = self.module_builder
96+
cloned_global_symbols = self.cloned_global_symbols
7897

79-
def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
8098
# We support resolution of tracked reference types. Currently this
8199
# only includes Tensors. All others we let the importer do what it
82100
# is going to do.
83-
if not isinstance(py_value, torch.Tensor):
101+
if not isinstance(literal, torch.Tensor):
84102
return None
85103

86104
# See if we know about it.
87-
mapping = module_builder.global_ref_tracker.track(py_value)
105+
mapping = module_builder.global_ref_tracker.track(literal)
88106
if mapping.is_empty:
89107
# If it is unknown, just let the default importer take it on.
90108
return None
@@ -101,7 +119,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
101119
cloned_global_symbols.add(materialized_global.symbol_name)
102120

103121
# Emit a global load and conversion.
104-
vtensor_type = gni._cc.tensor_to_vtensor_type(py_value)
122+
vtensor_type = gni._cc.tensor_to_vtensor_type(literal)
105123
loaded_value = util_d.GlobalLoadOp(
106124
materialized_global.ir_type, materialized_global.symbol_name
107125
).result
@@ -112,8 +130,6 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
112130
).result
113131
return converted_value
114132

115-
return resolver
116-
117133

118134
ALL_PASSES: Set[str] = set(["functorch_functionalize"])
119135
DEFAULT_PASSES: Tuple[str, ...] = ("functorch_functionalize",)
@@ -234,7 +250,7 @@ def flat_wrapped_f(*args):
234250
fx_importer = FxImporter(
235251
context=proc_trace.context,
236252
config_check=False,
237-
literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
253+
hooks=_Hooks(proc_trace.module_builder),
238254
py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
239255
)
240256
fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)

0 commit comments

Comments
 (0)