
Commit af5daad

Finalize support for Triton (#42)
Triton installation is a bit tricky, as it's mandatory to have either the NVIDIA or AMD backend installed on the system (which can also differ from version to version and be hardware dependent). So no automatic installation is provided with this commit. Instead, a PYTORCH_INDEX env variable is introduced to change the default CPU PyTorch nightly wheels to whatever index a user desires (see the README changes below for usage).
1 parent 88924af commit af5daad

File tree

8 files changed (+74, −28 lines)

Dockerfile

Lines changed: 1 addition & 1 deletion
````diff
@@ -46,7 +46,7 @@ RUN python3 -m venv /opt/venv && \
     /opt/venv/bin/pip install --pre torch-mlir torchvision \
         --extra-index-url=https://download.pytorch.org/whl/nightly/cpu \
         -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels && \
-    /opt/venv/bin/pip install fastapi uvicorn pytest httpx
+    /opt/venv/bin/pip install triton fastapi uvicorn pytest httpx
 
 # Create non-root user and fix permissions
 RUN useradd -u 10001 -m --shell /usr/sbin/nologin appuser && \
````

Dockerfile.backend

Lines changed: 4 additions & 4 deletions
````diff
@@ -15,10 +15,10 @@ RUN apt-get update && \
 RUN useradd -u 10001 -m --shell /usr/sbin/nologin appuser && \
     mkdir -p /home/appuser/.cache
 
-RUN wget -qO- https://apt.llvm.org/llvm.sh | bash -s -- 21 && \
+RUN wget -qO- https://apt.llvm.org/llvm.sh | bash -s -- 22 && \
     apt-get update && \
     apt-get install -y --no-install-recommends \
-        libmlir-21-dev mlir-21-tools && \
+        libmlir-22-dev mlir-22-tools && \
     rm -rf /var/lib/apt/lists/*
 
 COPY --chown=10001:10001 backend /app/backend
@@ -28,13 +28,13 @@ RUN python3 -m venv /opt/venv && \
     /opt/venv/bin/pip install --pre torch-mlir torchvision \
         --extra-index-url=https://download.pytorch.org/whl/nightly/cpu \
         -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels && \
-    /opt/venv/bin/pip install fastapi uvicorn pydantic
+    /opt/venv/bin/pip install triton fastapi uvicorn pydantic
 
 RUN chown -R appuser:appuser /home/appuser/.cache /app
 
 USER appuser
 
-ENV PATH="/opt/venv/bin:/usr/lib/llvm-21/bin:$PATH"
+ENV PATH="/opt/venv/bin:/usr/lib/llvm-22/bin:$PATH"
 
 EXPOSE 8000
 CMD ["uvicorn", "backend.server:app", "--host", "0.0.0.0", "--port", "8000"]
````

README.md

Lines changed: 17 additions & 11 deletions
````diff
@@ -38,19 +38,19 @@ tracing models through various IR stages and transformations.
 - Node.js + npm
 - PyTorch
 - Torch-MLIR
-- Triton
 - LLVM with mlir-opt
+- Triton
 
 To set up PyTorch and Torch-MLIR, it's a good idea to visit the https://github.com/llvm/torch-mlir repository and follow the instructions from there.
 
 The current version of the application is tested on Ubuntu 22.04 under the Windows Subsystem for Linux, using LLVM 22 dev.
 
 Triton requires that PyTorch be compiled with CUDA or ROCm support. When
 installing PyTorch, pick the desired accelerator build. For example, to install
-a CUDA 12.4 wheel you can run (note: this is not included in scripts and dockerfiles):
+a CUDA 12.8 wheel you can run (note: this is not included in the scripts and Dockerfiles; at least this works with my Blackwell GPU):
 
 ```bash
-pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu124
+pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/cu128
 ```
 
 ### Install dependencies
@@ -61,12 +61,24 @@ git clone https://github.com/MrSidims/PytorchExplorer.git
 cd PytorchExplorer
 ```
 
+To use custom builds of `torch-mlir-opt`, `mlir-opt`, etc. without placing them in your `$PATH`, configure the following environment variables:
+- `TORCH_MLIR_OPT_PATH`
+- `LLVM_BIN_PATH`
+- `TRITON_OPT_PATH`
+- `PYTORCH_INDEX` – index URL for installing PyTorch; defaults to the nightly CPU wheels.
+
+For example, to install CUDA-enabled nightly wheels (CUDA 12.8):
+```bash
+PYTORCH_INDEX=https://download.pytorch.org/whl/nightly/cu128 \
+source setup_backend.sh
+```
+
 Install frontend dependencies:
 ```bash
 source setup_frontend.sh
 ```
 
-Set up backend (Torch, MLIR, etc.):
+Set up the backend (Torch, MLIR, etc.; note that unless `PYTORCH_INDEX` is set, the script installs CPU wheels):
 ```bash
 source setup_backend.sh
 ```
@@ -76,11 +88,6 @@ If you already have a working venv for Torch-MLIR, you can just install FastAPI
 pip install fastapi uvicorn pytest httpx
 ```
 
-To use custom builds of `torch-mlir-opt`, `mlir-opt`, etc. without placing them in your `$PATH`, configure the following environment variables:
-- `TORCH_MLIR_OPT_PATH`
-- `LLVM_BIN_PATH`
-- `TRITON_OPT_PATH`
-
 ### Run the application
 
 If you reused the `setup_backend.sh` script, activate the environment with
@@ -220,5 +227,4 @@ For more details about IR lowering, please see [PyTorch Lowerings](docs/pytorch_lowering.md).
 
 ## Integration with your frontend or backend
 
-Refer to the [Integration Guide](docs/integration_guide.md) for details on the API contracts and communication between the frontend and backend used in this project.
-
+Refer to the [Integration Guide](docs/integration_guide.md) for details on the API contracts and communication between the frontend and backend used in this project.
````

backend/server.py

Lines changed: 28 additions & 7 deletions
````diff
@@ -1,6 +1,7 @@
 import subprocess
 import tempfile
 import os
+import linecache
 import glob
 import uuid
 import hashlib
@@ -413,7 +414,7 @@ def generate_target_gpu_ir(model, example_input, target: str) -> str:
         raise IRGenerationError(f"Failed to generate LLVM IR: {e}") from e
 
 
-# TODO: Figure out static compilation.
+# Compile Triton IR.
 def compile_triton_ir(
     code: str, ir_type: str, pipeline: List[Tuple[str, str]], dump_each: bool
 ) -> str:
@@ -470,6 +471,7 @@ def compile_triton_ir(
         "triton_gpu_ir": "*.ttgir",
         "triton_llvm_ir": "*.llir",
         "triton_nvptx": "*.ptx",
+        "triton_amdgpu": "*.hsaco",
     }
 
     pattern = pattern_map.get(ir_type)
@@ -537,8 +539,20 @@ def process_model(request: CodeRequest) -> str:
             request.code, request.ir_type, pipeline, request.dump_after_each_opt
         )
 
-    if request.ir_type == "raw_ir" and request.selected_language == "pytorch":
-        # Execute user Python, capture stdout.
+    if request.ir_type == "raw_ir" and request.selected_language in (
+        "pytorch",
+        "triton",
+    ):
+        # If raw IR is requested, we execute the user code directly.
+        # Prepare a fake file for linecache to make
+        # inspect.getsourcelines() work.
+        fake_name = "<string>"
+        source_code = request.code
+        lines = [ln + "\n" for ln in source_code.splitlines()]
+        linecache.cache[fake_name] = (
+            len(source_code), None, lines, fake_name
+        )
+        # Execute user code, capture stdout.
         try:
             with tempfile.TemporaryDirectory() as tmpdir:
                 stdout_path = os.path.join(tmpdir, "captured_output.txt")
@@ -555,10 +569,17 @@ def process_model(request: CodeRequest) -> str:
                 captured, build_pipeline(request), request.dump_after_each_opt
             )
         except Exception as e:
-            logger.exception("User code with manual IR print execution failed.")
-            raise PytorchExecutionError(
-                f"Code raised an exception during execution: {e}"
-            ) from e
+            logger.exception(
+                "User code with manual IR print execution failed."
+            )
+            if request.selected_language == "pytorch":
+                raise PytorchExecutionError(
+                    f"Code raised an exception during execution: {e}"
+                ) from e
+            else:
+                raise TritonExecutionError(
+                    f"Triton code execution raised an exception: {e}"
+                ) from e
 
     if request.ir_type == "raw_ir":
         return apply_optional_passes(
````
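
The `linecache` priming added above is what lets `inspect.getsourcelines()` recover source for code that arrives as a string over HTTP; Triton's `@triton.jit` inspects the kernel's source, which would otherwise fail for `exec()`-ed strings. A minimal standalone sketch of the trick (the `source` string and `kernel` function here are hypothetical, not from the repo):

```python
import inspect
import linecache

# Hypothetical user code received as a string, as in process_model.
source = "def kernel():\n    return 42\n"
fake_name = "<string>"

# Prime linecache with a fake entry: (size, mtime, lines, fullname).
# mtime=None tells linecache.checkcache() not to evict the entry.
lines = [ln + "\n" for ln in source.splitlines()]
linecache.cache[fake_name] = (len(source), None, lines, fake_name)

# exec() the string under the same fake filename...
namespace = {}
exec(compile(source, fake_name, "exec"), namespace)

# ...and inspect can now resolve the source via the linecache entry.
print(inspect.getsource(namespace["kernel"]))
```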

docs/integration_guide.md

Lines changed: 4 additions & 2 deletions
````diff
@@ -1,5 +1,5 @@
 # PyTorch Explorer — Integration Guide
-*(last updated : 2025-06-04, commit: 5c5c42)*
+*(last updated : 2025-08-03)*
 
 This document explains **how to integrate with the PyTorch explorer as IR-Playground compiler
 service**—either by
@@ -32,6 +32,7 @@ service**—either by
 | `TORCH_MLIR_OPT_PATH` | Directory ending with `/` that contains `torch-mlir-opt` | `/opt/llvm/bin/` |
 | `LLVM_BIN_PATH` | Directory that contains `mlir-opt`, `mlir-translate`, `opt`, `llc` | `/opt/llvm/bin/` |
 | `TRITON_OPT_PATH` | Directory that contains `triton-opt`, `triton-llvm-opt` | `/opt/triton/bin/` |
+| `PYTORCH_INDEX` | Extra index URL used by `setup_backend.sh` to install PyTorch | `https://download.pytorch.org/whl/nightly/cpu` |
 
 For the reference React UI (and any client based on it) set `NEXT_PUBLIC_BACKEND_URL`
 to point at the running backend instance when they live on different machines.
@@ -122,6 +123,7 @@ Calling this is optional but keeps /tmp tidy on long-running servers.
 | | triton_gpu_ir | *.ttgir | |
 | | triton_llvm_ir | *.llir | |
 | | triton_nvptx | *.ptx | |
+| | triton_amdgpu | *.hsaco | |
 | Raw | raw_ir | Echo-style (no generation) | |
 
 ## 4. **Backend internals (reference implementation)**
@@ -238,7 +240,7 @@ GET /version -> "ir-backend 1.1.0-rust"
 
 - New dialect -> implement `generate_<dialect>()`, register it in `process_model`, add value to §3 and to the frontend dropdown.
 - New compiler tool -> add a clause in `apply_optional_passes`.
-- Timeouts / resource limits -> see `compile_triton_ir(... timeout=20)`.
+- Timeouts / resource limits -> see `compile_triton_ir(... timeout=60)`.
 
 ## 8. **Appendix — 20-line TypeScript helper**
````
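
Given the IR-type table above, a Triton request can be exercised end-to-end with a few lines of Python. This is a hedged sketch, not the guide's canonical client: the `/compile` endpoint path and port are assumptions based on the reference Dockerfiles, and the field names follow `CodeRequest` as seen in `backend/server.py`:

```python
import httpx  # installed alongside the backend's test dependencies

# Assumed payload shape, mirroring CodeRequest fields from backend/server.py.
payload = {
    "code": open("add_kernel.py").read(),  # Triton kernel source to compile
    "ir_type": "triton_gpu_ir",            # maps to *.ttgir per the table above
    "selected_language": "triton",
    "dump_after_each_opt": False,
}

# The "/compile" path is an assumption; check the FastAPI routes before use.
resp = httpx.post("http://localhost:8000/compile", json=payload, timeout=120.0)
resp.raise_for_status()
print(resp.text)  # the requested IR dump
```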

docs/pytorch_lowering.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -131,4 +131,11 @@ Each stage can optionally be configured to **dump intermediate IR**, making the
 
 ---
 
-For Triton models, see the separate section on `compile_triton_ir`, which handles IR extraction from the Triton JIT cache.
+For Triton models, see the separate section on `compile_triton_ir`, which handles IR extraction from the Triton JIT cache. Supported
+`ir_type` values include:
+
+- `triton_ir` – Triton compiler dump (`*.ttir`)
+- `triton_gpu_ir` – `*.ttgir`
+- `triton_llvm_ir` – `*.llir`
+- `triton_nvptx` – `*.ptx`
+- `triton_amdgpu` – `*.hsaco`
````
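
As a rough illustration of the extraction step, `compile_triton_ir` runs the kernel and then picks those artifacts out of Triton's JIT cache by glob pattern. A sketch, under the assumption that the cache lives in the default `~/.triton/cache` location (Triton honors `TRITON_CACHE_DIR` to override it):

```python
import glob
import os

# Look for compiler artifacts left behind after a Triton kernel has run.
# The default cache location is a Triton implementation detail and may change.
cache_dir = os.environ.get(
    "TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache")
)
for pattern in ("*.ttir", "*.ttgir", "*.llir", "*.ptx", "*.hsaco"):
    for path in glob.glob(os.path.join(cache_dir, "**", pattern), recursive=True):
        print(pattern, "->", path)
```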

setup_backend.sh

Lines changed: 5 additions & 1 deletion
````diff
@@ -30,10 +30,14 @@ source mlir_venv/bin/activate
 
 echo "Installing torch-mlir and dependencies..."
 pip install --upgrade pip
+PYTORCH_INDEX="${PYTORCH_INDEX:-https://download.pytorch.org/whl/nightly/cpu}"
 pip install --pre torch-mlir torchvision \
-    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
+    --extra-index-url "$PYTORCH_INDEX" \
     -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels
 
+echo "Installing Triton..."
+pip install triton
+
 echo "Installing FastAPI and Uvicorn..."
 pip install fastapi uvicorn
 
````
src/app/ExplorerContent.js

Lines changed: 7 additions & 1 deletion
````diff
@@ -20,6 +20,8 @@ model = MyModel()
 example_input = torch.randn(4, 4)
 # If you have multiple models, wrap each model and input tensor pair using:
 # __explore__(model, input_tensor)
+# To use your own means to compile Triton IR, please
+# select raw IR output.
 `;
 
 const defaultTritonCode = `import triton
@@ -45,6 +47,8 @@ z = torch.empty_like(x)
 
 grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
 add_kernel[grid](x, y, z, N)
+# To use your own means to compile Triton IR, please
+# select raw IR output.
 `;
 
 const defaultRawIRCode = `module {
@@ -90,6 +94,8 @@ const tritonIROptions = [
   { value: "triton_gpu_ir", label: "Triton GPU IR" },
   { value: "triton_llvm_ir", label: "LLVM IR" },
   { value: "triton_nvptx", label: "NVPTX" },
+  { value: "triton_amdgpu", label: "ROCm" },
+  { value: "raw_ir", label: "Raw IR Output" },
 ];
 
 const rawIROptions = [{ value: "raw_ir", label: "Raw IR Output" }];
@@ -554,8 +560,8 @@ export default function ExplorerContent() {
                 style={{ margin: "10px 0" }}
               >
                 <option value="pytorch">PyTorch</option>
+                <option value="triton">Triton</option>
                 <option value="raw_ir">Raw IR Input</option>
-                <option value="triton">Triton (experimental support)</option>
               </select>
               <div
                 style={{
````
