Commit 8203b1d

Merge commit 'a85fab01dd6e35429f530c3b3886b33e8615366a'
2 parents: 69ceddb + a85fab0

12 files changed: +367, -139 lines


README.md

Lines changed: 4 additions & 0 deletions

@@ -138,6 +138,10 @@ arbitrary LLVM version.
   during the build. By default, this is the user's home directory. It
   can be changed anytime.
 
+- If you're running out of memory when building Triton, specify the `MAX_JOBS`
+  environment variable (to the `pip install -e python` command) to limit the
+  number of jobs.
+
 - Pass `--no-build-isolation` to `pip install` to make nop builds faster.
   Without this, every invocation of `pip install` uses a different symlink to
   cmake, and this forces ninja to rebuild most of the `.a` files.
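A minimal sketch of how these two tips combine, assuming the Triton repository root is the working directory; the MAX_JOBS value of 4 is an arbitrary example, not a recommendation from the README:

# Hedged sketch: run the documented editable install with a capped job count.
import os
import subprocess

env = dict(os.environ, MAX_JOBS="4")  # limit parallel compile jobs to reduce peak memory
subprocess.run(
    ["pip", "install", "-e", "python", "--no-build-isolation"],  # command from the README
    env=env,
    check=True,
)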

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_GLOBAL_PREFETCH",
     "TRITON_HIP_LOCAL_PREFETCH",
+    "TRITON_HIP_USE_ASYNC_COPY",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
     "TRITON_LLVM_DEBUG_ONLY",

Lines changed: 21 additions & 34 deletions

@@ -1,42 +1,29 @@
-import os
-import shutil
-
-import pytest
-
-import torch
 import triton
 import re
 
 
-@triton.jit
-def triton_():
-    return
+def test_triton_reproducer_path(monkeypatch, tmp_path):
+    # If we get a cache hit there will be no reproducer generated.
+    monkeypatch.setenv("TRITON_ALWAYS_COMPILE", "1")
+
+    @triton.jit
+    def triton_():
+        return
 
+    # We need a temporary empty file for MLIR to write the reproducer to, and
+    # the TRITON_REPRODUCER_PATH env var then enables crash-reproducer
+    # generation in MLIR.
+    repro_path = tmp_path / "repro.mlir"
+    repro_path.touch()
+    monkeypatch.setenv("TRITON_REPRODUCER_PATH", str(repro_path))
 
-@pytest.mark.skipif(not torch.xpu.is_available() and not torch.cuda.is_available(), reason="requires cuda")
-def test_reproducer():
-    tmpdir = ".tmp"
-    reproducer = 'triton-reproducer.mlir'
-    if os.path.exists(tmpdir):
-        shutil.rmtree(tmpdir, ignore_errors=True)
-    if os.path.exists(reproducer):
-        os.remove(reproducer)
-    os.environ["TRITON_CACHE_DIR"] = tmpdir
-    os.environ["TRITON_REPRODUCER_PATH"] = reproducer
+    # Run the kernel so MLIR will generate a crash reproducer. It doesn't really
+    # matter what the kernel does, just that the PassManager runs its passes.
     triton_[(1, )]()
-    foundPipeline = ""
-    with open(reproducer, 'r') as f:
-        line = f.read()
-        if 'pipeline:' in line:
-            foundPipeline = line
-    if 0 == len(foundPipeline):
-        raise Exception("Failed to find pipeline info in reproducer file.")
 
-    ttgir_to_llvm_pass = re.compile("convert-triton-{{.*}}gpu-to-llvm")
-    if ttgir_to_llvm_pass.search(foundPipeline):
-        raise Exception("Failed to find triton passes in pipeline")
-    # cleanup
-    if os.path.exists(tmpdir):
-        shutil.rmtree(tmpdir, ignore_errors=True)
-    if os.path.exists(reproducer):
-        os.remove(reproducer)
+    repro = repro_path.read_text()
+    assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}"
+    m = re.search(r"pipeline: \"(.*)\"", repro)
+    assert m, "Expected to match pass pipeline after \"pipeline:\" in MLIR reproducer"
+    pipeline_str = m.group(1)
+    assert pipeline_str, "Expected non-empty pass pipeline in MLIR reproducer"
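For context, a minimal sketch of how the same two environment variables could be used to capture a reproducer outside the test suite; the no-op kernel, the temporary output path, and the GPU requirement are illustrative assumptions, not part of this commit:

# Hedged sketch (assumes a working GPU backend): capture an MLIR crash
# reproducer for an arbitrary kernel using the env vars exercised above.
import os
import tempfile

import triton

os.environ["TRITON_ALWAYS_COMPILE"] = "1"              # bypass the compilation cache
repro = os.path.join(tempfile.mkdtemp(), "repro.mlir")
open(repro, "w").close()                               # MLIR expects the file to exist
os.environ["TRITON_REPRODUCER_PATH"] = repro


@triton.jit
def noop_kernel():                                     # placeholder; any kernel works
    return


noop_kernel[(1, )]()                                   # launching runs the PassManager
print("pipeline:" in open(repro).read())               # True once the reproducer is written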

python/tutorials/09-persistent-matmul.py

Lines changed: 1 addition & 1 deletion

@@ -725,9 +725,9 @@ def bench(K, dtype, reps=1000, warmup_reps=10000):
     if dtype == torch.float16:
         bench_fn(reps, warmup_reps, torch_matmul, a, b)
     bench_fn(reps, warmup_reps, matmul, a, b.T)
-    bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
     if supports_tma():
+        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
         bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)

test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir

Lines changed: 35 additions & 7 deletions

@@ -27,13 +27,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 tt.return
 }
 
-// CHECK-LABEL: wmma1_dot
-tt.func @wmma1_dot(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
+// CHECK-LABEL: wmma1_dot_f16
+tt.func @wmma1_dot_f16(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xf16, #mma1>) {
 // CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
 // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
 // CHECK: llvm.mlir.undef : vector<16xf16>
 // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xf16>
-// CHECK: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+// CHECK: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
 %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xf16, #mma1>
 // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
 // CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>

@@ -50,11 +50,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
 // CHECK: llvm.mlir.undef : vector<16xbf16>
 // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<16xbf16>
-// CHECK: rocdl.wmma.bf16.16x16x16.bf16 {{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+// CHECK: wmma.bf16.16x16x16.bf16{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
 %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xbf16, #mma1>
 tt.return
 }
 
+// CHECK-LABEL: wmma1_dot_f16_tied
+tt.func @wmma1_dot_f16_tied(%arg0: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xf16, #mma1>) {
+// CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK: llvm.mlir.undef : vector<16xf16>
+// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xf16>
+// CHECK-COUNT-2: wmma.f16.16x16x16.f16.tied{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xf16, #mma1>
+// CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xf16>
+// CHECK: llvm.mlir.undef : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16, f16)>
+tt.return
+}
+
+// CHECK-LABEL: wmma1_dot_bf16_tied
+tt.func @wmma1_dot_bf16_tied(%arg0: tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<64x16xbf16, #mma1>) {
+// CHECK-COUNT-32: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK: llvm.mlir.undef : vector<16xbf16>
+// CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xbf16>
+// CHECK-COUNT-2: wmma.bf16.16x16x16.bf16.tied{{.*}} : (vector<16xi16>, vector<16xi16>, vector<16xbf16>, i1) -> vector<16xbf16>
+%0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<64x16xbf16, #mma1>
+// CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<16xbf16>
+// CHECK: llvm.mlir.undef : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+// CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16, bf16)>
+tt.return
+}
+
 // CHECK-LABEL: wmma1_dot_int8_32
 tt.func @wmma1_dot_int8_32(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg2: tensor<16x16xi32, #mma1>) {
 // CHECK-COUNT-16: llvm.extractvalue %{{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)>

@@ -64,7 +92,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi8>
 // CHECK: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
 // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-// CHECK: rocdl.wmma.i32.16x16x16.iu8 {{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+// CHECK: wmma.i32.16x16x16.iu8{{.*}} : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
 %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
 // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
 tt.return

@@ -79,7 +107,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 // CHECK-COUNT-16: llvm.insertelement {{.*}} : vector<16xi4>
 // CHECK: llvm.bitcast %{{.*}} : vector<16xi4> to vector<2xi32>
 // CHECK-COUNT-8: llvm.extractvalue %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
-// CHECK: rocdl.wmma.i32.16x16x16.iu4 {{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+// CHECK: wmma.i32.16x16x16.iu4{{.*}} : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
 %0 = tt.dot %arg0, %arg1, %arg2 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xi4, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<16x16xi4, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<16x16xi32, #mma1>
 // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
 tt.return

@@ -196,7 +224,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
 // CHECK-COUNT-32: llvm.insertelement
 // CHECK-COUNT-8: llvm.extractvalue %arg2
 // CHECK-COUNT-8: llvm.insertelement
-// CHECK-COUNT-2: rocdl.wmma.f16.16x16x16.f16 {{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+// CHECK-COUNT-2: wmma.f16.16x16x16.f16{{.*}} : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
 %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<2x16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> * tensor<2x32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>> -> tensor<2x16x16xf16, #mma1>
 // CHECK-COUNT-8: llvm.extractelement
 // CHECK-COUNT-8: llvm.insertvalue
