Skip to content

Commit 8328998

Browse files
authored
Allow printing all IR in torch_mlir.compile (#2669)
This PR adds the `enable_ir_printing` option to `torch_mlir.compile`, which can be used to print the IR for all intermediate passes. When running the added test file via: ```shell $ python test/python/compile.py 2> tiny.stderr ``` the file `tiny.stderr` is about 700 KB.
1 parent 11cc92d commit 8328998

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

projects/pt1/python/torch_mlir/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def compile(model: torch.nn.Module,
319319
backend_legal_ops: Optional[Sequence[str]] = None,
320320
extra_library: Iterable[Callable] = [],
321321
verbose: bool = False,
322-
use_make_fx: bool = False):
322+
use_make_fx: bool = False,
323+
enable_ir_printing: bool = False):
323324
"""Convert a PyTorch model to MLIR.
324325
325326
Args:
@@ -348,7 +349,13 @@ def compile(model: torch.nn.Module,
348349
into the abstract interpretation library. See
349350
`docs/adding_abstract_interpretation_functions.md` for more info
350351
on the format the functions should have.
351-
verbose: If true, print extra information about the conversion.
352+
verbose: If true, print extra information about the conversion to
353+
stdout.
354+
enable_ir_printing: If true, print the IR before and after each pass to
355+
stderr. This is equivalent to setting MLIR's `-print-ir-after-all`
356+
flag. Note that this can easily generate many gigabytes of text,
357+
so make sure to pipe stderr to a file (for example, run
358+
`python tinymodel.py 2> tinymodel.stderr` on Linux).
352359
353360
Returns:
354361
An MLIR module that contains the converted model in the specified
@@ -452,6 +459,7 @@ def compile(model: torch.nn.Module,
452459
mb.module,
453460
f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})",
454461
"Lowering TorchScript IR -> Torch Backend IR",
462+
enable_ir_printing=enable_ir_printing,
455463
)
456464

457465
return _lower_mlir_module(verbose, output_type, mb.module)

projects/pt1/python/torch_mlir/compiler_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class TorchMlirCompilerError(Exception):
2727

2828
def run_pipeline_with_repro_report(module,
2929
pipeline: str,
30-
description: str):
30+
description: str,
31+
enable_ir_printing: bool = False):
3132
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
3233
module_name = get_module_name_for_debug_dump(module)
3334
try:
@@ -36,8 +37,11 @@ def run_pipeline_with_repro_report(module,
3637
asm_for_error_report = module.operation.get_asm(
3738
large_elements_limit=10, enable_debug_info=True)
3839
# Lower module in place to make it ready for compiler backends.
39-
with module.context:
40+
with module.context as ctx:
4041
pm = PassManager.parse(pipeline)
42+
if enable_ir_printing:
43+
ctx.enable_multithreading(False)
44+
pm.enable_ir_printing()
4145
pm.run(module.operation)
4246
except Exception as e:
4347
# TODO: More robust.

test/python/compile.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# RUN: %PYTHON -s %s 2>&1 | FileCheck %s
2+
3+
import gc
4+
import sys
5+
import torch
6+
import torch_mlir
7+
8+
9+
def run_test(f):
10+
print("TEST:", f.__name__, file=sys.stderr)
11+
f()
12+
gc.collect()
13+
14+
15+
class TinyModel(torch.nn.Module):
16+
def __init__(self):
17+
super(TinyModel, self).__init__()
18+
19+
self.linear = torch.nn.Linear(20, 30)
20+
21+
def forward(self, x):
22+
x = self.linear(x)
23+
return x
24+
25+
26+
# CHECK-LABEL: TEST: test_enable_ir_printing
27+
@run_test
28+
def test_enable_ir_printing():
29+
torch_mlir.compile(TinyModel(),
30+
torch.ones(1, 3, 20, 20),
31+
output_type="linalg-on-tensors",
32+
enable_ir_printing=True)
33+
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize)
34+
# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} {

0 commit comments

Comments
 (0)