 import torch.utils._pytree as pytree
 from torch.export.graph_signature import OutputSpec, OutputKind
 from torch.export import ExportedProgram
+from torch._dynamo.backends.common import aot_autograd
 
 from torch_mlir import fx
 from torch_mlir_e2e_test.configs.utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
 )
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
+from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
 
 
 def refine_result_type(_result):
@@ -31,17 +33,100 @@ def refine_result_type(_result):
 class FxImporterTestConfig(TestConfig):
     """TestConfig that runs the torch.nn.Module with Fx Importer"""
 
-    def __init__(self, backend, output_type="linalg-on-tensors"):
+    def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False):
         super().__init__()
         self._backend = backend
+        self._torch_compile = torch_compile
         self._output_type = output_type
 
     def compile(
         self, program: torch.nn.Module, verbose: bool = False
     ) -> torch.nn.Module:
         return program
 
-    def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
+    def run(self, artifact: torch.nn.Module, trace: Trace):
+        return (
+            self._export_run(artifact, trace)
+            if not self._torch_compile
+            else self._stateless_run(artifact, trace)
+        )
+
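+    # torch.compile path: Dynamo traces each trace item, and the backend built
+    # below lowers the captured graph through the stateless FX importer.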
+    def _stateless_run(self, artifact: torch.nn.Module, trace: Trace):
+        dynamic_argument_pos = None
+        dynamic_dim_pos = None
+        annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
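+        # Scan the @annotate_args shape annotations for the first -1 (dynamic)
+        # dimension so it can be marked dynamic for Dynamo below.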
+        for i, annotation in enumerate(annotations):
+            if i == 0:  # Skip the "self" annotation.
+                continue
+            if not annotation[2]:
+                raise ValueError(
+                    "Can only compile inputs annotated as having value semantics."
+                )
+            for dim_i, dim in enumerate(annotation[0]):
+                if dim == -1:
+                    dynamic_argument_pos = i - 1
+                    dynamic_dim_pos = dim_i
+                    break
+            if dynamic_argument_pos is not None:
+                break
+        result: Trace = []
+        for item in trace:
+
+            def _base_backend(gm: torch.fx.GraphModule, example_inputs):
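+                # Dynamo adds SymInt placeholders for dynamic shapes; erase the
+                # unused ones so they do not become dead arguments in the IR.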
+                for node in gm.graph.nodes:
+                    if node.op == "placeholder":
+                        if (
+                            isinstance(node.meta["val"], torch.SymInt)
+                            and not node.users
+                        ):
+                            gm.graph.erase_node(node)
+                module = fx.stateless_fx_import(
+                    gm,
+                    output_type=self._output_type,
+                    model_name=artifact.__class__.__name__,
+                )
+                module = self._backend.compile(module)
+                backend_module = self._backend.load(module)
+
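+                # Adapter called by the compiled program: drops non-tensor
+                # (SymInt) arguments and round-trips tensors through numpy.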
+                def invoke_func(*torch_inputs):
+                    torch_inputs = [
+                        x for x in torch_inputs if isinstance(x, torch.Tensor)
+                    ]
+                    with torch.no_grad():
+                        numpy_inputs = recursively_convert_to_numpy(torch_inputs)
+                        return recursively_convert_from_numpy(
+                            getattr(backend_module, artifact.__class__.__name__)(
+                                *numpy_inputs
+                            )
+                        )
+
+                return invoke_func
+
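+            # Wrap the backend with aot_autograd so the captured graph is
+            # functionalized before _base_backend imports it.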
+            fw_compiler = aot_autograd(fw_compiler=_base_backend)
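+            # Honor a -1 annotation: mark that input dimension dynamic so
+            # Dynamo compiles a shape-polymorphic graph rather than specializing.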
+            if dynamic_argument_pos is not None:
+                torch._dynamo.mark_dynamic(
+                    item.inputs[dynamic_argument_pos], dynamic_dim_pos
+                )
+            module = torch.compile(artifact, backend=fw_compiler)
+            outputs = module(*item.inputs)
+            result.append(
+                TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs)
+            )
+        return result
+
+    def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         result: Trace = []
         for item in trace:
             prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))