Skip to content

Commit a85fab0

Browse files
authored
[TEST] Refactor reproducer test to be device-agnostic (#6230)
I have refactored the existing reproducer test to make better use of pytest fixtures and be device-independent. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because `this PR only updates the LLVM pin, so CI is sufficient`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 678f083 commit a85fab0

File tree

1 file changed

+21
-34
lines changed

1 file changed

+21
-34
lines changed
Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,29 @@
1-
import os
2-
import shutil
3-
4-
import pytest
5-
6-
import torch
71
import triton
82
import re
93

104

11-
@triton.jit
12-
def triton_():
13-
return
5+
def test_triton_reproducer_path(monkeypatch, tmp_path):
6+
# If we get a cache hit there will be no reproducer generated
7+
monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1")
8+
9+
@triton.jit
10+
def triton_():
11+
return
1412

13+
# We need an temp empty file for MLIR to write the reproducer to, and then
14+
# the TRITON_REPRODUCER_PATH env var enables crash the reproduction
15+
# generation in MLIR.
16+
repro_path = tmp_path / "repro.mlir"
17+
repro_path.touch()
18+
monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
1519

16-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
17-
def test_reproducer():
18-
tmpdir = ".tmp"
19-
reproducer = 'triton-reproducer.mlir'
20-
if os.path.exists(tmpdir):
21-
shutil.rmtree(tmpdir, ignore_errors=True)
22-
if os.path.exists(reproducer):
23-
os.remove(reproducer)
24-
os.environ["TRITON_CACHE_DIR"] = tmpdir
25-
os.environ["TRITON_REPRODUCER_PATH"] = reproducer
20+
# Run the kernel so MLIR will generate a crash reproducer. It doesn't really
21+
# matter what the kernel does, just that the PassManager runs its passes.
2622
triton_[(1, )]()
27-
foundPipeline = ""
28-
with open(reproducer, 'r') as f:
29-
line = f.read()
30-
if 'pipeline:' in line:
31-
foundPipeline = line
32-
if 0 == len(foundPipeline):
33-
raise Exception("Failed to find pipeline info in reproducer file.")
3423

35-
ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm")
36-
if ttgir_to_llvm_pass.search(foundPipeline):
37-
raise Exception("Failed to find triton passes in pipeline")
38-
# cleanup
39-
if os.path.exists(tmpdir):
40-
shutil.rmtree(tmpdir, ignore_errors=True)
41-
if os.path.exists(reproducer):
42-
os.remove(reproducer)
24+
repro = repro_path.read_text()
25+
assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}"
26+
m = re.search(r"pipeline: \"(.*)\"", repro)
27+
assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer"
28+
pipeline_str = m.group(1)
29+
assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer"

0 commit comments

Comments
 (0)