Skip to content

Commit 4c6e463

Browse files
authored
[test] refine fx_importer_backend with verbose option (#4084)
* make `verbose` (aka `-v`) option's behavior to be same with other backends. `verbose=True` would print the IR of each stage. * clear `verbose` and `enable_ir_printing` option.
1 parent c7f8ac0 commit 4c6e463

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False
4242
def compile(
4343
self, program: torch.nn.Module, verbose: bool = False
4444
) -> torch.nn.Module:
45+
self._verbose = verbose
4546
return program
4647

4748
def run(self, artifact: torch.nn.Module, trace: Trace):
@@ -84,6 +85,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
8485
gm,
8586
output_type=self._output_type,
8687
model_name=artifact.__class__.__name__,
88+
verbose=self._verbose,
8789
)
8890
module = self._backend.compile(module)
8991
backend_module = self._backend.load(module)
@@ -128,6 +130,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
128130
# While the current e2e tests don't exercise symbolic shapes,
129131
# enabling this here ensures they don't regress either.
130132
import_symbolic_shape_expressions=True,
133+
verbose=self._verbose,
131134
)
132135
module = self._backend.compile(module)
133136
backend_module = self._backend.load(module)

python/torch_mlir/fx.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@
2626

2727
def _module_lowering(
2828
verbose,
29+
enable_ir_printing,
2930
output_type,
3031
torch_mod,
3132
extra_library_file_name=None,
3233
backend_legal_ops=None,
3334
):
35+
if verbose:
36+
print("\n====================")
37+
print("TorchFX IR")
38+
print(torch_mod)
3439

3540
if output_type == OutputType.RAW:
36-
if verbose:
37-
print(torch_mod)
3841
return torch_mod
3942
# TODO: pass extra_library_file_name by caller
4043

@@ -59,7 +62,7 @@ def _module_lowering(
5962
torch_mod,
6063
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
6164
"Lowering TorchFX IR -> Torch Backend IR",
62-
enable_ir_printing=verbose,
65+
enable_ir_printing=enable_ir_printing,
6366
)
6467
return lower_mlir_module(verbose, output_type, torch_mod)
6568

@@ -76,6 +79,7 @@ def export_and_import(
7679
decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
7780
func_name: str = "main",
7881
enable_graph_printing: bool = False,
82+
verbose: bool = False,
7983
enable_ir_printing: bool = False,
8084
backend_legal_ops: Optional[list[str]] = None,
8185
**kwargs,
@@ -115,6 +119,7 @@ def export_and_import(
115119
)
116120

117121
return _module_lowering(
122+
verbose,
118123
enable_ir_printing,
119124
OutputType.get(output_type),
120125
fx_importer.module,
@@ -129,6 +134,7 @@ def stateless_fx_import(
129134
hooks: Optional[FxImporterHooks] = None,
130135
model_name: str = "main",
131136
enable_graph_printing: bool = False,
137+
verbose: bool = False,
132138
enable_ir_printing: bool = False,
133139
backend_legal_ops: Optional[list[str]] = None,
134140
):
@@ -140,6 +146,7 @@ def stateless_fx_import(
140146
fx_importer = FxImporter(context=context, hooks=hooks)
141147
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
142148
return _module_lowering(
149+
verbose,
143150
enable_ir_printing,
144151
OutputType.get(output_type),
145152
fx_importer.module,

0 commit comments

Comments
 (0)