|
1 | | -import os |
2 | | -import shutil |
3 | | - |
4 | | -import pytest |
5 | | - |
6 | | -import torch |
7 | 1 | import triton |
8 | 2 | import re |
9 | 3 |
|
10 | 4 |
|
11 | | -@triton.jit |
12 | | -def triton_(): |
13 | | - return |
| 5 | +def test_triton_reproducer_path(monkeypatch, tmp_path): |
| 6 | + # If we get a cache hit there will be no reproducer generated |
| 7 | + monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1") |
| 8 | + |
| 9 | + @triton.jit |
| 10 | + def triton_(): |
| 11 | + return |
14 | 12 |
|
| 13 | + # We need an temp empty file for MLIR to write the reproducer to, and then |
| 14 | + # the TRITON_REPRODUCER_PATH env var enables crash the reproduction |
| 15 | + # generation in MLIR. |
| 16 | + repro_path = tmp_path / "repro.mlir" |
| 17 | + repro_path.touch() |
| 18 | + monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path)) |
15 | 19 |
|
16 | | -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") |
17 | | -def test_reproducer(): |
18 | | - tmpdir = ".tmp" |
19 | | - reproducer = 'triton-reproducer.mlir' |
20 | | - if os.path.exists(tmpdir): |
21 | | - shutil.rmtree(tmpdir, ignore_errors=True) |
22 | | - if os.path.exists(reproducer): |
23 | | - os.remove(reproducer) |
24 | | - os.environ["TRITON_CACHE_DIR"] = tmpdir |
25 | | - os.environ["TRITON_REPRODUCER_PATH"] = reproducer |
| 20 | + # Run the kernel so MLIR will generate a crash reproducer. It doesn't really |
| 21 | + # matter what the kernel does, just that the PassManager runs its passes. |
26 | 22 | triton_[(1, )]() |
27 | | - foundPipeline = "" |
28 | | - with open(reproducer, 'r') as f: |
29 | | - line = f.read() |
30 | | - if 'pipeline:' in line: |
31 | | - foundPipeline = line |
32 | | - if 0 == len(foundPipeline): |
33 | | - raise Exception("Failed to find pipeline info in reproducer file.") |
34 | 23 |
|
35 | | - ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm") |
36 | | - if ttgir_to_llvm_pass.search(foundPipeline): |
37 | | - raise Exception("Failed to find triton passes in pipeline") |
38 | | - # cleanup |
39 | | - if os.path.exists(tmpdir): |
40 | | - shutil.rmtree(tmpdir, ignore_errors=True) |
41 | | - if os.path.exists(reproducer): |
42 | | - os.remove(reproducer) |
| 24 | + repro = repro_path.read_text() |
| 25 | + assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}" |
| 26 | + m = re.search(r"pipeline: \"(.*)\"", repro) |
| 27 | + assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer" |
| 28 | + pipeline_str = m.group(1) |
| 29 | + assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer" |
0 commit comments