1515)
1616
1717from iree .turbine .aot import *
18+ # from iree.turbine.aot import FxProgramsBuilder, export
1819
1920DEFAULT_COMPILE_FLAGS = [
2021 "--iree-hal-target-device=hip" , # change to your backend (e.g., local, cuda, vulkan)
@@ -38,40 +39,38 @@ def _as_tuple(x):
3839 return tuple (x )
3940 return (x ,)
4041
41- def run_iree_vs_torch_fx (
42+ def export_torch_module_to_mlir (
4243 module : torch .nn .Module ,
4344 args = (),
4445 kwargs = None ,
4546 * ,
46- atol = 1e-4 ,
47- rtol = 0.0 ,
48- entrypoint = "forward" ,
49- parameters_path = None ,
47+ mlir_path : Path ,
48+ target_fn = "run_forward" ,
5049):
5150 """
52- Exports MLIR via FxProgramsBuilder(model) and compares IREE vs Torch eager .
51+ Export torch module to MLIR and get torch eager reference output .
5352
5453 Args:
55- module: torch.nn.Module under test
56- args: example positional inputs (tuple required)
57- kwargs: example kwargs
58- atol/rtol: tolerances passed to torch.testing.assert_close
59- entrypoint: the method name exported/invoked ("forward" by default)
54+ module: torch.nn.Module under test
55+ args: example positional inputs (tuple required)
56+ kwargs: example kwargs
57+ mlir_path: Path where to save the MLIR file
58+ target_fn: name of the exported function
59+
60+ Returns:
61+ Tuple of (torch_eager_output, export_output)
6062 """
6163 kwargs = kwargs or {}
6264 args = _as_tuple (args )
6365 torch .manual_seed (1234 )
64- target_fn = "run_forward"
65- entrypoint = target_fn
6666
67- # ---- 1) Torch eager reference ----
67+ # ---- Torch eager reference ----
6868 module .eval ()
6969 with torch .no_grad ():
7070 expected = module (* args , ** kwargs )
7171
7272 fxb = FxProgramsBuilder (module )
7373
74-
7574 # empty tensors for export input
7675 # there needs to be one corresponding to each arg
7776 # NOTE: assuming args are not nested.
@@ -95,55 +94,99 @@ def run_iree_vs_torch_fx(
9594 def _ (module , * fn_args ):
9695 return module .forward (* fn_args )
9796
98- # Export the selected entry point (callable) from the instance `module`.
99- # We pass a bound method so export() can trace that entry.
100- # target_fn = getattr(type(module), entrypoint)
101-
10297 export_output = export (fxb , import_symbolic_shape_expressions = True )
98+ export_output .save_mlir (mlir_path )
10399
104- # The turbine builder attaches a Torch-MLIR operation on the exported program.
105- # Retrieve MLIR text and compile it with iree-compile.
106- # Note: sharktank's exporter uses the same fx-builder object to drive MLIR generation.
107- # See export_paged_llm_v1.py (fxb usage).
108- # mlir_text = ep.mlir_module_operation.get_asm(enable_debug_info=False)
100+ return expected , export_output
109101
110- # Compile MLIR -> VMFB
111- with tempfile .TemporaryDirectory () as td :
112- td = Path (td )
113- mlir_path = td / "module.mlir"
114- vmfb_path = td / "module.vmfb"
115- export_output .save_mlir (mlir_path )
116102
117- iree .compiler .compile_file (
118- str (mlir_path ),
119- output_file = str (vmfb_path ),
120- extra_args = DEFAULT_COMPILE_FLAGS ,
121- )
103+ def compile_mlir_to_vmfb (
104+ mlir_path : Path ,
105+ vmfb_path : Path ,
106+ * ,
107+ compile_flags = None ,
108+ ):
109+ """
110+ Compile MLIR file to VMFB.
111+
112+ Args:
113+ mlir_path: Path to the MLIR file
114+ vmfb_path: Path where to save the VMFB file
115+ compile_flags: List of compilation flags (uses DEFAULT_COMPILE_FLAGS if None)
116+ """
117+ compile_flags = compile_flags or DEFAULT_COMPILE_FLAGS
118+
119+ iree .compiler .compile_file (
120+ str (mlir_path ),
121+ output_file = str (vmfb_path ),
122+ extra_args = compile_flags ,
123+ )
124+
125+
126+ def run_iree_module_from_vmfb (
127+ vmfb_path : Path ,
128+ args = (),
129+ * ,
130+ entrypoint = "run_forward" ,
131+ parameters_path = None ,
132+ driver = "hip" ,
133+ device_count = 1 ,
134+ ):
135+ """
136+ Load VMFB and run with IREE.
137+
138+ Args:
139+ vmfb_path: Path to the VMFB file
140+ args: Input arguments for the module
141+ entrypoint: Name of the function to run
142+ parameters_path: Optional path to parameters file
143+ driver: IREE driver to use
144+ device_count: Number of devices
145+
146+ Returns:
147+ IREE module output
148+ """
149+ args = _as_tuple (args )
150+
151+ # Load & run with IREE
152+ devices = get_iree_devices (driver = driver , device_count = device_count )
153+ iree_module , vm_context , _ = load_iree_module (
154+ module_path = str (vmfb_path ),
155+ devices = devices ,
156+ parameters_path = parameters_path ,
157+ )
158+ iree_args = prepare_iree_module_function_args (args = args , devices = devices )
159+
160+ iree_out = run_iree_module_function (
161+ module = iree_module ,
162+ vm_context = vm_context ,
163+ args = iree_args ,
164+ device = devices [0 ],
165+ function_name = entrypoint ,
166+ )
167+
168+ return iree_out
122169
123- # Load & run with IREE
124- devices = get_iree_devices (driver = "hip" , device_count = 1 ) # adjust driver
125- iree_module , vm_context , _ = load_iree_module (
126- module_path = str (vmfb_path ),
127- devices = devices ,
128- parameters_path = parameters_path ,
129- )
130- iree_args = prepare_iree_module_function_args (args = args , devices = devices )
131-
132- # For FxProgramsBuilder export, the function name is typically "forward".
133- # If you exported a different method, pass entrypoint=<that name>.
134- # do we need logic to identify the correct entrypoints, will we have multi entry point executions in these pytests?
135-
136- iree_out = run_iree_module_function (
137- module = iree_module ,
138- vm_context = vm_context ,
139- args = iree_args ,
140- device = devices [0 ],
141- function_name = entrypoint ,
142- )
143170
144- # TODO: refactor to separate it from iree compile and run
171+ def compare_iree_torch_outputs (
172+ iree_output ,
173+ torch_output ,
174+ * ,
175+ atol = 1e-4 ,
176+ rtol = 0.0 ,
177+ ):
178+ """
179+ Compare IREE output with torch eager reference and assert closeness.
180+
181+ Args:
182+ iree_output: Output from IREE module
183+ torch_output: Output from torch eager execution
184+ atol/rtol: tolerances passed to torch.testing.assert_close
185+ """
145186 # Convert and compare
146- actual = iree_to_torch (* iree_out )
187+ actual = iree_to_torch (* iree_output )
188+ expected = torch_output
189+
147190 if isinstance (expected , torch .Tensor ):
148191 expected = (expected ,)
149192 if isinstance (actual , torch .Tensor ):
@@ -154,3 +197,70 @@ def _(module, *fn_args):
154197 torch .testing .assert_close (actual , expected , atol = atol , rtol = rtol )
155198 print (f"actual : { actual } " )
156199 print (f"expected : { expected } " )
200+
201+
202+ def run_iree_vs_torch_fx (
203+ module : torch .nn .Module ,
204+ args = (),
205+ kwargs = None ,
206+ * ,
207+ atol = 1e-4 ,
208+ rtol = 0.0 ,
209+ entrypoint = "run_forward" ,
210+ parameters_path = None ,
211+ compile_flags = None ,
212+ driver = "hip" ,
213+ device_count = 1 ,
214+ ):
215+ """
216+ Wrapper for MLIR export via FxProgramsBuilder(model) and IREE vs Torch eager comparison.
217+
218+ Args:
219+ module: torch.nn.Module under test
220+ args: example positional inputs (tuple required)
221+ kwargs: example kwargs
222+ atol/rtol: tolerances passed to torch.testing.assert_close
223+ entrypoint: the method name exported/invoked ("run_forward" by default)
224+ parameters_path: Optional path to parameters file
225+ compile_flags: List of compilation flags (uses DEFAULT_COMPILE_FLAGS if None)
226+ driver: IREE driver to use
227+ device_count: Number of devices
228+ """
229+ with tempfile .TemporaryDirectory () as td :
230+ td = Path (td )
231+ mlir_path = td / "module.mlir"
232+ vmfb_path = td / "module.vmfb"
233+
234+ # Export to MLIR and get torch reference
235+ torch_output , _ = export_torch_module_to_mlir (
236+ module = module ,
237+ args = args ,
238+ kwargs = kwargs ,
239+ mlir_path = mlir_path ,
240+ target_fn = entrypoint ,
241+ )
242+
243+ # Compile MLIR to VMFB
244+ compile_mlir_to_vmfb (
245+ mlir_path = mlir_path ,
246+ vmfb_path = vmfb_path ,
247+ compile_flags = compile_flags ,
248+ )
249+
250+ # Run with IREE
251+ iree_output = run_iree_module_from_vmfb (
252+ vmfb_path = vmfb_path ,
253+ args = args ,
254+ entrypoint = entrypoint ,
255+ parameters_path = parameters_path ,
256+ driver = driver ,
257+ device_count = device_count ,
258+ )
259+
260+ # Compare outputs
261+ compare_iree_torch_outputs (
262+ iree_output = iree_output ,
263+ torch_output = torch_output ,
264+ atol = atol ,
265+ rtol = rtol ,
266+ )
0 commit comments