Skip to content

Commit 625c8cb

Browse files
authored
[BACKEND] Retain mlir reproducer temporaries from prior run pass pipelines (#8113)
Currently MLIR reproducers for each pass pipeline run overrides the previous `TRITON_REPRODUCER_PATH` path. This change allows for including a reproducer suffix when calling pm.run() to allow for retaining all previously run pipeline reproducers prior to the most recently run pass pipeline. This is important to add because with multiple pipelines, it is necessary to retain all previous pipelines reproducers to reproduce the full compilation sequence. - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/python/test` for end-to-end tests - Select one of the following. - [X] I have not added any `lit` tests.
1 parent 85e99d6 commit 625c8cb

File tree

4 files changed

+32
-18
lines changed

4 files changed

+32
-18
lines changed

python/src/ir.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1856,7 +1856,7 @@ void init_triton_ir(py::module &&m) {
18561856
})
18571857
.def(
18581858
"run",
1859-
[](PassManager &self, ModuleOp &mod) {
1859+
[](PassManager &self, ModuleOp &mod, std::string repro_pipeline_tag) {
18601860
// TODO: maybe dump module to file and print error for better
18611861
// diagnostics
18621862

@@ -1867,6 +1867,11 @@ void init_triton_ir(py::module &&m) {
18671867
auto reproducerPath =
18681868
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
18691869
if (!reproducerPath.empty()) {
1870+
if (reproducerPath != "-") {
1871+
std::string repro_suffix =
1872+
"." + repro_pipeline_tag + ".repro.mlir";
1873+
reproducerPath += repro_suffix;
1874+
}
18701875
auto anchorName = self.getOpAnchorName();
18711876
auto passes = self.getPasses();
18721877
Operation *op = mod.getOperation();
Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import triton
22
import re
3+
import os
34

45

56
def test_triton_reproducer_path(monkeypatch, tmp_path):
@@ -13,17 +14,25 @@ def triton_():
1314
# We need an temp empty file for MLIR to write the reproducer to, and then
1415
# the TRITON_REPRODUCER_PATH env var enables crash the reproduction
1516
# generation in MLIR.
16-
repro_path = tmp_path / "repro.mlir"
17-
repro_path.touch()
17+
repro_path = tmp_path / "repro_prefix"
1818
monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
1919

2020
# Run the kernel so MLIR will generate a crash reproducer. It doesn't really
2121
# matter what the kernel does, just that the PassManager runs its passes.
2222
triton_[(1, )]()
2323

24-
repro = repro_path.read_text()
25-
assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}"
26-
m = re.search(r"pipeline: \"(.*)\"", repro)
27-
assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer"
28-
pipeline_str = m.group(1)
29-
assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer"
24+
stages = {
25+
'make_ttir': "triton-combine",
26+
'make_ttgir': "triton.*-coalesce",
27+
'make_llir': "convert-triton-.*gpu-to-llvm",
28+
}
29+
30+
for stage_name, stage_pipeline_check in stages.items():
31+
assert os.path.exists(str(repro_path) + '.' + stage_name + '.repro.mlir')
32+
curr_repro_path = tmp_path / ("repro_prefix." + stage_name + ".repro.mlir")
33+
repro = curr_repro_path.read_text()
34+
assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {curr_repro_path}. Got:\n{repro}"
35+
m = re.search(r"pipeline: \"(.*" + stage_pipeline_check + ".*)\"", repro)
36+
assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer"
37+
pipeline_str = m.group(1)
38+
assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer"

third_party/amd/backend/compiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def make_ttir(mod, metadata, options):
196196
passes.ttir.add_triton_licm(pm)
197197
passes.common.add_symbol_dce(pm)
198198
passes.ttir.add_loop_unroll(pm)
199-
pm.run(mod)
199+
pm.run(mod, 'make_ttir')
200200
return mod
201201

202202
@staticmethod
@@ -205,7 +205,7 @@ def make_ttgir(mod, metadata, options):
205205
pm.enable_debug()
206206
passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
207207
options.num_ctas)
208-
pm.run(mod)
208+
pm.run(mod, 'make_ttgir_early')
209209
pm = ir.pass_manager(mod.context)
210210
pm.enable_debug()
211211
passes.ttgpuir.add_coalesce(pm)
@@ -254,7 +254,7 @@ def make_ttgir(mod, metadata, options):
254254
passes.common.add_symbol_dce(pm)
255255
if use_async_copy:
256256
amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
257-
pm.run(mod)
257+
pm.run(mod, 'make_ttgir')
258258
return mod
259259

260260
@staticmethod
@@ -270,7 +270,7 @@ def gluon_to_ttgir(src, metadata, options):
270270
passes.gluon.add_canonicalizer(pm)
271271
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
272272

273-
pm.run(mod)
273+
pm.run(mod, 'gluon_to_ttgir')
274274
return mod
275275

276276
@staticmethod
@@ -323,7 +323,7 @@ def make_llir(src, metadata, options):
323323
passes.llvmir.add_di_scope(pm)
324324

325325
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
326-
pm.run(mod)
326+
pm.run(mod, 'make_llir')
327327

328328
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
329329
llvm.init_targets()

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def make_ttir(mod, metadata, opt, capability):
239239
passes.common.add_cse(pm)
240240
passes.common.add_symbol_dce(pm)
241241
passes.ttir.add_loop_unroll(pm)
242-
pm.run(mod)
242+
pm.run(mod, 'make_ttir')
243243
return mod
244244

245245
@staticmethod
@@ -316,7 +316,7 @@ def make_ttgir(mod, metadata, opt, capability):
316316
passes.common.add_cse(pm)
317317
passes.common.add_canonicalizer(pm)
318318

319-
pm.run(mod)
319+
pm.run(mod, 'make_ttgir')
320320
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
321321
tensordesc_meta = mod.get_tensordesc_metadata()
322322
metadata["tensordesc_meta"] = tensordesc_meta
@@ -334,7 +334,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
334334
passes.gluon.add_canonicalizer(pm)
335335
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
336336

337-
pm.run(mod)
337+
pm.run(mod, 'gluon_to_ttgir')
338338
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
339339
return mod
340340

@@ -374,7 +374,7 @@ def make_llir(self, src, metadata, options, capability):
374374
if CUDABackend.instrumentation:
375375
CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
376376

377-
pm.run(mod)
377+
pm.run(mod, 'make_llir')
378378
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
379379
llvm.init_targets()
380380
context = llvm.context()

0 commit comments

Comments
 (0)