22
22
)
23
23
from torch .fx .passes .shape_prop import TensorMetadata
24
24
25
- from ...dynamo .passes import (
26
- DEFAULT_DECOMPOSITIONS ,
27
- )
25
+ # TODO: Switch to upstream fx_importer vs local fork when ready.
26
+ # from iree.compiler.extras.fx_importer import (
27
+ # GraphNodeImporter,
28
+ # FxImporter,
29
+ # FxImporterHooks,
30
+ # )
28
31
29
32
from ...importers .fx_importer import (
30
33
GraphNodeImporter ,
31
34
FxImporter ,
35
+ FxImporterHooks ,
36
+ )
37
+
38
+ from ...dynamo .passes import (
39
+ DEFAULT_DECOMPOSITIONS ,
32
40
)
33
41
34
42
from ...support .ir_imports import (
68
76
StringAttrOrStr = Union [StringAttr , str ]
69
77
70
78
71
- def _make_literal_resolver (module_builder : ModuleBuilder ):
72
- # When we first encounter a global during import, we have to pull it
73
- # into the local module being populated by the GraphNodeImporter. This
74
- # will exactly match the global in the target module we are merging into
75
- # and exists so that the IR is valid during Fx import. We keep the set of
76
- # symbols we have done this to here.
77
- cloned_global_symbols : Set [str ] = set ()
79
+ class _Hooks (FxImporterHooks ):
80
+ __slots__ = [
81
+ "cloned_global_symbols" ,
82
+ "module_builder" ,
83
+ ]
84
+
85
+ def __init__ (self , module_builder : ModuleBuilder ):
86
+ self .module_builder = module_builder
87
+ # When we first encounter a global during import, we have to pull it
88
+ # into the local module being populated by the GraphNodeImporter. This
89
+ # will exactly match the global in the target module we are merging into
90
+ # and exists so that the IR is valid during Fx import. We keep the set of
91
+ # symbols we have done this to here.
92
+ self .cloned_global_symbols : set [str ] = set ()
93
+
94
+ def resolve_literal (self , gni : GraphNodeImporter , literal : Any ) -> Optional [Value ]:
95
+ module_builder = self .module_builder
96
+ cloned_global_symbols = self .cloned_global_symbols
78
97
79
- def resolver (py_value : Any , gni : GraphNodeImporter ) -> Optional [Value ]:
80
98
# We support resolution of tracked reference types. Currently this
81
99
# only includes Tensors. All others we let the importer do what it
82
100
# is going to do.
83
- if not isinstance (py_value , torch .Tensor ):
101
+ if not isinstance (literal , torch .Tensor ):
84
102
return None
85
103
86
104
# See if we know about it.
87
- mapping = module_builder .global_ref_tracker .track (py_value )
105
+ mapping = module_builder .global_ref_tracker .track (literal )
88
106
if mapping .is_empty :
89
107
# If it is unknown, just let the default importer take it on.
90
108
return None
@@ -101,7 +119,7 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
101
119
cloned_global_symbols .add (materialized_global .symbol_name )
102
120
103
121
# Emit a global load and conversion.
104
- vtensor_type = gni ._cc .tensor_to_vtensor_type (py_value )
122
+ vtensor_type = gni ._cc .tensor_to_vtensor_type (literal )
105
123
loaded_value = util_d .GlobalLoadOp (
106
124
materialized_global .ir_type , materialized_global .symbol_name
107
125
).result
@@ -112,8 +130,6 @@ def resolver(py_value: Any, gni: GraphNodeImporter) -> Optional[Value]:
112
130
).result
113
131
return converted_value
114
132
115
- return resolver
116
-
117
133
118
134
ALL_PASSES : Set [str ] = set (["functorch_functionalize" ])
119
135
DEFAULT_PASSES : Tuple [str , ...] = ("functorch_functionalize" ,)
@@ -234,7 +250,7 @@ def flat_wrapped_f(*args):
234
250
fx_importer = FxImporter (
235
251
context = proc_trace .context ,
236
252
config_check = False ,
237
- literal_resolver_callback = _make_literal_resolver (proc_trace .module_builder ),
253
+ hooks = _Hooks (proc_trace .module_builder ),
238
254
py_attr_tracker = proc_trace .module_builder .fx_py_attr_tracker ,
239
255
)
240
256
fx_importer .import_stateless_graph (gm .graph , func_name = self .function_name )
0 commit comments