
Commit d8eedcf

Add pre-defined lowering to NVPTX, AMDGPU and SPIR-V for PyTorch
Originally I wasn't planning to add these, since it's possible to just call llc. On the other hand, I started building this tool not only for engineers familiar with the MLIR compiler toolchain(s), but also for novices (like myself), and from that perspective it might make sense to have an (almost) final lowering for the target GPU.
1 parent c7bf63a commit d8eedcf
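For context, this is roughly the by-hand workflow the new options replace: generate LLVM IR with the existing "LLVM IR" option, then run opt and llc yourself. A minimal sketch in Python, assuming opt and llc are on PATH; the input file name model.ll is hypothetical:

import subprocess

# Hypothetical input: textual LLVM IR saved from the tool's "LLVM IR" option.
with open("model.ll") as f:
    llvm_ir = f.read()

# Optimize, then lower with llc: the same two stages the new
# nvptx/amdgpu/spirv options run automatically (see backend/server.py below).
optimized = subprocess.run(
    ["opt", "-O2", "-S", "-o", "-"],
    input=llvm_ir, capture_output=True, text=True, check=True,
).stdout
ptx = subprocess.run(
    ["llc", "-mtriple=nvptx64-nvidia-cuda", "-o", "-"],
    input=optimized, capture_output=True, text=True, check=True,
).stdout
print(ptx)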

File tree

3 files changed: +123, -3 lines

backend/server.py

Lines changed: 26 additions & 1 deletion
@@ -338,7 +338,7 @@ def generate_llvm_mlir(
 # First generate LLVM MLIR and then translate it to LLVM IR.
 def generate_llvm_ir(
     model, example_input, pipeline: List[Tuple[str, str]], dump_each: bool
-):
+) -> str:
     try:
         lowered_mlir = lower_to_llvm_mlir(model, example_input)

@@ -365,6 +365,23 @@ def generate_llvm_ir(
         raise IRGenerationError("Failed to generate LLVM IR.")


+# Generate NVPTX, AMDGPU or SPIR-V.
+def generate_target_gpu_ir(model, example_input, target: str) -> str:
+    try:
+        llvm_ir_module = generate_llvm_ir(model, example_input, [], False)
+        pipeline: list[tuple[str, str]] = [("opt", "-O2")]
+        if target == "nvptx":
+            pipeline.append(("llc", "-mtriple=nvptx64-nvidia-cuda"))
+        elif target == "amdgpu":
+            pipeline.append(("llc", "-mtriple=amdgcn-amd-amdhsa"))
+        elif target == "spirv":
+            pipeline.append(("llc", "-mtriple=spirv64-unknown-unknown"))
+        return apply_optional_passes(llvm_ir_module, pipeline, False)
+    except Exception:
+        logger.exception("Error generating target GPU IR.")
+        raise IRGenerationError("Failed to generate target GPU IR.")
+
+
 # TODO: Figure out static compilation.
 def compile_triton_ir(
     code: str, ir_type: str, pipeline: List[Tuple[str, str]], dump_each: bool
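
apply_optional_passes itself is not part of this diff; as a mental model, it presumably pipes the textual module through each (tool, flags) stage in order. A rough sketch under that assumption (the function body below is hypothetical, not the repository's implementation):

import shlex
import subprocess

def apply_optional_passes_sketch(module_text, pipeline, dump_each):
    # Pipe the module through each (tool, flags) stage on stdin/stdout,
    # e.g. ("opt", "-O2") followed by ("llc", "-mtriple=nvptx64-nvidia-cuda").
    out = module_text
    for tool, flags in pipeline:
        args = [tool, *shlex.split(flags), "-o", "-"]
        if tool == "opt":
            args.insert(1, "-S")  # keep LLVM IR textual between stages
        result = subprocess.run(args, input=out, capture_output=True,
                                text=True, check=True)
        out = result.stdout
        if dump_each:
            print(f"=== after {tool} {flags} ===\n{out}")
    return out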
@@ -547,6 +564,14 @@ def process_model(request: CodeRequest) -> str:
         combined_output += generate_llvm_ir(
             model, example_input, pipeline, request.dump_after_each_opt
         )
+    elif request.ir_type in ("nvptx", "amdgpu", "spirv"):
+        # FIXME?: this could really just be generate_llvm_ir with the pipeline.
+        # A dedicated function is preferred in case something smarter has to
+        # happen before lowering; for SPIR-V, for example, it would be nice to
+        # create a kernel first, i.e. write a pass and execute it here.
+        combined_output += generate_target_gpu_ir(
+            model, example_input, request.ir_type
+        )
     else:
         combined_output += "IR type not supported yet."
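
For reference, the generic route the FIXME mentions would look roughly like this, reusing generate_llvm_ir from this same file (a sketch, not part of the commit):

# Hypothetical equivalent of generate_target_gpu_ir(model, example_input,
# "nvptx"), expressed as generate_llvm_ir plus an explicit (tool, flags)
# pipeline, as the FIXME suggests.
combined_output += generate_llvm_ir(
    model, example_input,
    [("opt", "-O2"), ("llc", "-mtriple=nvptx64-nvidia-cuda")],
    False,
)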

src/app/page.js

Lines changed: 9 additions & 2 deletions
@@ -79,6 +79,9 @@ const pytorchIROptions = [
   { value: "stablehlo_mlir", label: "StableHLO MLIR" },
   { value: "llvm_mlir", label: "LLVM MLIR" },
   { value: "llvm_ir", label: "LLVM IR" },
+  { value: "nvptx", label: "NVPTX" },
+  { value: "amdgpu", label: "AMDGPU" },
+  { value: "spirv", label: "SPIR-V" },
   { value: "raw_ir", label: "Raw IR Output" },
 ];

@@ -756,8 +759,12 @@ export default function PyTorchTritonExplorer() {
     "triton_gpu_ir",
     "triton_llvm_ir",
   ].includes(irWin.selectedIR);
-  const allowLlvmOpt =
-    irWin.selectedIR !== "triton_nvptx";
+  const allowLlvmOpt = ![
+    "triton_nvptx",
+    "nvptx",
+    "amdgpu",
+    "spirv",
+  ].includes(irWin.selectedIR);
   const allowLLC = allowLlvmOpt;
   const allowUserTool = true;

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import pytest
+import httpx
+
+API_URL = "http://localhost:8000/generate_ir"
+
+code = """
+import torch
+import torch.nn as nn
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(4, 4)
+
+    def forward(self, x):
+        return torch.relu(self.linear(x))
+
+model = MyModel()
+example_input = torch.randn(4, 4)
+"""
+
+
+def test_torch_nvptx_linear():
+    payload = {
+        "code": code,
+        "ir_type": "nvptx",
+        "custom_pipeline": [],
+        "torch_mlir_opt": "",
+        "mlir_opt": "",
+        "mlir_translate": "",
+        "llvm_opt": "",
+        "llc": "",
+        "user_tool": "",
+        "dump_after_each_opt": False,
+    }
+
+    response = httpx.post(API_URL, json=payload)
+    assert response.status_code == 200
+
+    ir = response.json()["output"]
+
+    assert "Generated by LLVM NVPTX Back-End" in ir
+    assert ".visible .func (.param .align 8 .b8 func_retval0[56]) main" in ir
+
+
+def test_torch_amdgpu_linear():
+    payload = {
+        "code": code,
+        "ir_type": "amdgpu",
+        "custom_pipeline": [],
+        "torch_mlir_opt": "",
+        "mlir_opt": "",
+        "mlir_translate": "",
+        "llvm_opt": "",
+        "llc": "",
+        "user_tool": "",
+        "dump_after_each_opt": False,
+    }
+
+    response = httpx.post(API_URL, json=payload)
+    assert response.status_code == 200
+
+    ir = response.json()["output"]
+
+    assert "amdgcn-amd-amdhsa--gfx700" in ir
+    assert "main:" in ir
+
+
+def test_torch_spirv_linear():
+    payload = {
+        "code": code,
+        "ir_type": "spirv",
+        "custom_pipeline": [],
+        "torch_mlir_opt": "",
+        "mlir_opt": "",
+        "mlir_translate": "",
+        "llvm_opt": "",
+        "llc": "",
+        "user_tool": "",
+        "dump_after_each_opt": False,
+    }
+
+    response = httpx.post(API_URL, json=payload)
+    assert response.status_code == 200
+
+    ir = response.json()["output"]
+
+    assert "OpCapability Kernel" in ir
