
Commit 550d2dc

Add raw_ir output for pytorch (#4)
With this output type, a user can specify their own way to print IR in the code, and that output will be respected by the tool. Example use case:

```python
import torch
import torch.nn as nn
from torch_mlir import fx
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir.fx import OutputType

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = MyModel()
example_input = torch.randn(4, 4)
module = fx.export_and_import(model, example_input, output_type=OutputType.LINALG_ON_TENSORS)
print(module)
```

This patch also brings a few modifications to the README.
2 parents: 4f419f2 + ae59722

4 files changed: +97 −8 lines changed

README.md

Lines changed: 12 additions & 5 deletions
```diff
@@ -26,8 +26,6 @@ tracing models through various IR stages and transformations.
 
 - (PyTorch) The model and input tensor must be initialized in the provided code. If multiple models are defined, it is recommended to explicitly pair each model and its input tensor using the internal `__explore__(model, input)` function.
 
-- (PyTorch) The current version does not recognize or capture user attempts to dump IR inside the input PyTorch module. It is planned that, in the future, if the user manually calls `fx.export_and_import()` (or similar IR-producing APIs), the app will use that IR as the base and apply the user-defined custom toolchain.
-
 - (Triton) The current implementation runs Triton kernels and retrieves IR dumps from the Triton cache directory. Timeout is set to 20s.
 
 ## Getting Started
```
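For reference, a minimal sketch of the explicit pairing mentioned in the (PyTorch) bullet above. The model names and shapes are illustrative, and `__explore__` is injected into the execution namespace by the backend, so the fallback definition below is only needed when running the snippet standalone:

```python
import torch
import torch.nn as nn

# Fallback so the snippet also runs outside the explorer; inside the tool,
# __explore__ is provided by the backend's extract_model_input_pairs().
if "__explore__" not in globals():
    def __explore__(model, input_tensor):
        print("explored:", type(model).__name__, tuple(input_tensor.shape))

encoder = nn.Linear(8, 4)
decoder = nn.Linear(4, 8)

# Explicitly pair each model with its own input tensor.
__explore__(encoder, torch.randn(2, 8))
__explore__(decoder, torch.randn(2, 4))
```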
````diff
@@ -41,11 +39,13 @@ tracing models through various IR stages and transformations.
 - Triton
 - LLVM with mlir-opt
 
-Current version is tested on Ubuntu 22.04 windows subsystem using LLVM 21 dev.
+To set up PyTorch and Torch-MLIR, it's a good idea to visit the https://github.com/llvm/torch-mlir repository and follow the instructions there.
+
+The current version of the application is tested on Ubuntu 22.04 under the Windows Subsystem for Linux using LLVM 21 dev.
 
 ### Install dependencies
 
-In case of missing prerequisites here are some scripts to help set them up.
+In case of missing prerequisites, here are some scripts to help set them up (they run on Debian and its derivatives).
 
 ```bash
 git clone https://github.com/MrSidims/PytorchExplorer.git
````
````diff
@@ -74,6 +74,8 @@ If you want to use your builds of the tools like `torch-mlir-opt`, `mlir-opt` etc.
 npm run start:all
 ```
 
+Then open http://localhost:3000/ in your browser and enjoy!
+
 ### Run the tests
 
 With the application (or just backend) started, run:
````
````diff
@@ -82,9 +84,14 @@ With the application (or just backend) started, run:
 pytest tests -v
 ```
 
+## User manual
+
+TBD
+
 ## Implementation details
 
-The app uses fx.export_and_import to inspect IR output for PyTorch.
+The app uses `fx.export_and_import` under the hood to inspect IR output for PyTorch; therefore, for pre-defined lowering paths, the module is required to have a `forward` method.
+
 Lowering to LLVM IR goes through:
 
 ```bash
````

backend/server.py

Lines changed: 33 additions & 3 deletions
```diff
@@ -25,6 +25,10 @@
 import atexit
 import shutil
 
+import io
+import sys
+from contextlib import redirect_stdout, redirect_stderr
+
 app = FastAPI()
 
 app.add_middleware(
@@ -44,6 +48,7 @@
 class CodeRequest(BaseModel):
     code: str
     ir_type: str
+    selected_language: Optional[str] = "pytorch"
     custom_pipeline: Optional[List[str]] = []
     torch_mlir_opt: Optional[str] = ""
     mlir_opt: Optional[str] = ""
```
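For illustration, a request body matching the extended `CodeRequest` could look like the sketch below (field values are illustrative; the field set follows the test at the bottom of this commit). Since `selected_language` defaults to `"pytorch"`, existing clients that omit the field keep the old behavior:

```python
# Hypothetical /generate_ir payload exercising the new field.
payload = {
    "code": "print('IR printed by user code')",
    "ir_type": "raw_ir",
    "selected_language": "pytorch",  # new field; defaults to "pytorch" when omitted
    "custom_pipeline": [],
    "torch_mlir_opt": "",
    "mlir_opt": "",
}
```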
```diff
@@ -60,15 +65,20 @@ class FreeIRCacheRequest(BaseModel):
 
 # Get model-tensor pairs to process.
 def extract_model_input_pairs(code: str):
-    import traceback
-
     explore_pairs = []
 
     def __explore__(model, input_tensor):
         explore_pairs.append((model, input_tensor))
 
     exec_globals = {"__explore__": __explore__}
-    exec(code, exec_globals)
+
+    # Suppress stdout and stderr.
+    with open(os.devnull, "w") as devnull:
+        try:
+            with redirect_stdout(devnull), redirect_stderr(devnull):
+                exec(code, exec_globals)
+        except Exception:
+            pass
 
     # If no __explore__ calls found, try matching models and tensors heuristically.
     if not explore_pairs:
```
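A usage sketch for the function above, assuming it is imported from the backend module (the `backend.server` path is an assumption): prints from the user code are swallowed by the redirect, and an exception during `exec` leaves `explore_pairs` with whatever was collected before the failure.

```python
# Hypothetical direct call; in the app this runs as part of request handling.
from backend.server import extract_model_input_pairs  # assumed import path

user_code = """
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
print("this print is suppressed")  # redirected to os.devnull
__explore__(model, torch.randn(2, 4))
"""

pairs = extract_model_input_pairs(user_code)
assert len(pairs) == 1  # one (model, input_tensor) pair collected
```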
```diff
@@ -429,6 +439,26 @@ def process_model(request: CodeRequest) -> str:
     if request.ir_type.startswith("triton"):
         return compile_triton_ir(request.code, request.ir_type)
 
+    if request.ir_type == "raw_ir" and request.selected_language == "pytorch":
+        # Execute user Python, capture stdout.
+        try:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                stdout_path = os.path.join(tmpdir, "captured_output.txt")
+                with open(stdout_path, "w") as f, redirect_stdout(
+                    f
+                ), redirect_stderr(f):
+                    exec_globals = {}
+                    exec(request.code, exec_globals)
+
+                with open(stdout_path, "r") as f:
+                    captured = f.read()
+
+                return apply_optional_passes(
+                    captured, build_pipeline(request), request.dump_after_each_opt
+                )
+        except Exception as e:
+            return f"Error executing user code: {str(e)}"
+
     if request.ir_type == "raw_ir":
         return apply_optional_passes(
             request.code,
```
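The new branch boils down to a run-and-capture pattern; here is a minimal standalone sketch of just that pattern (no FastAPI, no pass pipeline), using only the standard library:

```python
import os
import tempfile
from contextlib import redirect_stdout, redirect_stderr


def capture_user_output(code: str) -> str:
    """Execute a code string and return everything it printed to stdout/stderr."""
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "captured_output.txt")
        with open(path, "w") as f, redirect_stdout(f), redirect_stderr(f):
            exec(code, {})
        with open(path) as f:
            return f.read()


print(capture_user_output("print('pretend this is IR')"))
```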

src/app/page.js

Lines changed: 2 additions & 0 deletions
```diff
@@ -78,6 +78,7 @@ const pytorchIROptions = [
   { value: "stablehlo_mlir", label: "StableHLO MLIR" },
   { value: "llvm_mlir", label: "LLVM MLIR" },
   { value: "llvm_ir", label: "LLVM IR" },
+  { value: "raw_ir", label: "Raw IR Output" },
 ];
 
 const tritonIROptions = [
@@ -308,6 +309,7 @@ export default function PyTorchTritonExplorer() {
       const body = {
         code,
         ir_type: window.selectedIR,
+        selected_language: selectedLanguage,
         custom_pipeline: [],
         torch_mlir_opt: window.pipeline
           .filter((p) => p.tool === "torch-mlir-opt")
```
Lines changed: 50 additions & 0 deletions
```diff
@@ -0,0 +1,50 @@
+import pytest
+import httpx
+
+API_URL = "http://localhost:8000/generate_ir"
+
+
+def test_user_controlled_ir_print():
+    code = """
+import torch
+import torch.nn as nn
+from torch_mlir import fx
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.fx import OutputType
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(4, 4)
+
+    def forward(self, x):
+        return torch.relu(self.linear(x))
+
+model = MyModel()
+example_input = torch.randn(4, 4)
+module = fx.export_and_import(model, example_input, output_type=OutputType.LINALG_ON_TENSORS)
+print(module)
+"""
+
+    payload = {
+        "code": code,
+        "ir_type": "raw_ir",
+        "custom_pipeline": [],
+        "torch_mlir_opt": "",
+        "mlir_opt": '--one-shot-bufferize="bufferize-function-boundaries"',
+        "mlir_translate": "",
+        "llvm_opt": "",
+        "llc": "",
+        "user_tool": "",
+        "dump_after_each_opt": False,
+    }
+
+    response = httpx.post(API_URL, json=payload)
+    assert response.status_code == 200
+
+    ir = response.json()["output"]
+
+    assert "affine_map" in ir
+    assert "memref.global" in ir
+    assert "arith.constant" in ir
+    assert "linalg.matmul" in ir
```
