Skip to content

Commit a21c24d

Browse files
Merge commit 'a8adf9bbc170ab43478e6a32424966f5cf78ef9a'
2 parents 061707a + a8adf9b commit a21c24d

File tree

7 files changed

+247
-117
lines changed

7 files changed

+247
-117
lines changed

python/src/ir.cc

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -567,29 +567,9 @@ void init_triton_ir(py::module &&m) {
567567
// .def("has_attr", &::FuncOp::hasAttr)
568568
.def("finalize",
569569
[](FuncOp &self) -> void {
570-
// Remove dead code
571-
// 1. Unreachable code after return
572-
self.walk([&](Block *block) {
573-
Operation *retOp = nullptr;
574-
// It's better to not use walk here because we only want to
575-
// check operations in the current block
576-
for (auto &op : block->getOperations()) {
577-
if (isa<ReturnOp>(op))
578-
if (retOp == nullptr) {
579-
retOp = &op;
580-
break;
581-
}
582-
}
583-
if (retOp && retOp != &block->back()) {
584-
auto pos = retOp->getIterator();
585-
pos++;
586-
auto *newBlock = block->splitBlock(pos);
587-
newBlock->erase();
588-
}
589-
});
590-
// 2. Check if the result of tl.advance is used
591-
self.walk([&](Operation *op) {
592-
if (isa<AdvanceOp>(op) && op->getResult(0).use_empty())
570+
// Check if the result of tl.advance is used
571+
self.walk([&](AdvanceOp op) {
572+
if (op->getResult(0).use_empty())
593573
outputWarning(op->getLoc(), "The result of tl.advance is not "
594574
"being used. Note that tl.advance "
595575
"does not have any side effects. "

python/test/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import pytest
21
import os
2+
import pytest
33
import tempfile
44

55

66
def pytest_addoption(parser):
7-
parser.addoption("--device", action="store", default='cuda')
7+
parser.addoption("--device", action="store", default="cuda")
88

99

1010
@pytest.fixture

python/test/unit/language/test_core.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4903,6 +4903,33 @@ def nested_while(data, countPtr):
49034903
assert data[0] == 40
49044904

49054905

4906+
def test_constexpr_if_return(device):
4907+
# Reproducer for #4883, return statement in an if with a constexpr causes
4908+
# errors when combined with non-trivial control flow graphs
4909+
4910+
@triton.jit
4911+
def kernel(Semaphore, Out, total: tl.constexpr):
4912+
if total == 1:
4913+
tl.store(Out, tl.program_id(0))
4914+
return
4915+
4916+
prev = tl.atomic_add(Semaphore, 1)
4917+
if prev + 1 != total:
4918+
return
4919+
4920+
tl.store(Out, tl.program_id(0) + prev)
4921+
4922+
sem = torch.zeros((), device=device, dtype=torch.int32)
4923+
out = torch.empty((), device=device, dtype=torch.int32)
4924+
kernel[(1, )](sem, out, 1)
4925+
assert out.item() == 0
4926+
4927+
sem = torch.zeros((), device=device, dtype=torch.int32)
4928+
out = torch.full((), fill_value=-1, device=device, dtype=torch.int32)
4929+
kernel[(4, )](sem, out, 4)
4930+
assert out.item() >= 0
4931+
4932+
49064933
# -----------------------
49074934
# test extra
49084935
# -----------------------

python/test/unit/test_debug_dump.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,49 @@
1-
import triton
2-
import triton.language as tl
31
import os
2+
from contextlib import contextmanager
3+
44
import torch
5+
import triton
6+
import triton.language as tl
7+
58

9+
@contextmanager
10+
def enable_dump_context(pass_name="1"):
11+
try:
12+
os.environ["MLIR_ENABLE_DUMP"] = pass_name
13+
yield
14+
finally:
15+
os.environ["MLIR_ENABLE_DUMP"] = "0"
616

7-
def test_fn_dump(capfd, device):
17+
18+
def test_fn_dump(capfd, device, fresh_triton_cache):
819
N = 1024
920
src = torch.zeros(N, device=device)
1021

11-
grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
22+
grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]), )
1223

1324
@triton.jit
1425
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
1526
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1627
x = tl.load(src + offsets, mask=offsets < N) + 1
1728
tl.store(src + offsets, x, mask=offsets < N)
1829

19-
os.environ['MLIR_ENABLE_DUMP'] = '1'
20-
BLOCK_SIZE = 16
21-
_kernel[grid](src, N, BLOCK_SIZE)
30+
with enable_dump_context():
31+
BLOCK_SIZE = 16
32+
_kernel[grid](src, N, BLOCK_SIZE)
2233
captured = capfd.readouterr()
34+
print(captured.err)
2335
assert "IR Dump Before" in captured.err
2436
assert "tt.func public @_kernel" in captured.err
2537

26-
os.environ['MLIR_ENABLE_DUMP'] = '_kernel'
27-
BLOCK_SIZE = 32
28-
_kernel[grid](src, N, BLOCK_SIZE)
38+
with enable_dump_context("_kernel"):
39+
BLOCK_SIZE = 32
40+
_kernel[grid](src, N, BLOCK_SIZE)
2941
captured = capfd.readouterr()
3042
assert "IR Dump Before" in captured.err
3143
assert "tt.func public @_kernel" in captured.err
3244

33-
os.environ['MLIR_ENABLE_DUMP'] = '_kernel2'
34-
BLOCK_SIZE = 64
35-
_kernel[grid](src, N, BLOCK_SIZE)
45+
with enable_dump_context("_kernel2"):
46+
BLOCK_SIZE = 64
47+
_kernel[grid](src, N, BLOCK_SIZE)
3648
captured = capfd.readouterr()
3749
assert "IR Dump Before" not in captured.err
38-
39-
os.environ['MLIR_ENABLE_DUMP'] = '0'
Lines changed: 121 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,106 @@
1-
import triton
2-
import triton.language as tl
31
import os
2+
from contextlib import contextmanager
3+
44
import pytest
55
import torch
6+
import triton
7+
import triton.language as tl
8+
9+
10+
@contextmanager
11+
def enable_remark_context():
12+
try:
13+
os.environ["MLIR_ENABLE_REMARK"] = "1"
14+
yield
15+
finally:
16+
os.environ["MLIR_ENABLE_REMARK"] = "0"
617

718

819
def is_perf_warning_enabled():
9-
return os.environ.get('MLIR_ENABLE_REMARK', '0') == '1'
20+
return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1"
1021

1122

1223
def is_cuda():
1324
return triton.runtime.driver.active.get_current_target().backend == "cuda"
1425

1526

16-
def test_mma_remark(capfd):
27+
def test_mma_remark(capfd, fresh_triton_cache):
1728
if is_cuda():
1829
capability = torch.cuda.get_device_capability()
1930
if capability[0] < 9:
2031
pytest.skip("Requires sm >= 90 to run")
2132

22-
os.environ['MLIR_ENABLE_REMARK'] = '1'
23-
2433
@triton.jit
25-
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn):
26-
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),
27-
block_shape=(32, 128), order=(1, 0))
28-
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
29-
block_shape=(128, 32), order=(0, 1))
30-
c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),
31-
block_shape=(32, 32), order=(1, 0))
34+
def matmul_kernel(
35+
a_ptr,
36+
b_ptr,
37+
c_ptr,
38+
M,
39+
N,
40+
K,
41+
stride_am,
42+
stride_ak,
43+
stride_bk,
44+
stride_bn,
45+
stride_cm,
46+
stride_cn,
47+
):
48+
a_block_ptr = tl.make_block_ptr(
49+
base=a_ptr,
50+
shape=(M, K),
51+
strides=(stride_am, stride_ak),
52+
offsets=(0, 0),
53+
block_shape=(32, 128),
54+
order=(1, 0),
55+
)
56+
b_block_ptr = tl.make_block_ptr(
57+
base=b_ptr,
58+
shape=(K, N),
59+
strides=(stride_bk, stride_bn),
60+
offsets=(0, 0),
61+
block_shape=(128, 32),
62+
order=(0, 1),
63+
)
64+
c_block_ptr = tl.make_block_ptr(
65+
base=c_ptr,
66+
shape=(M, N),
67+
strides=(stride_cm, stride_cn),
68+
offsets=(0, 0),
69+
block_shape=(32, 32),
70+
order=(1, 0),
71+
)
3272
a = tl.load(a_block_ptr)
3373
b = tl.load(b_block_ptr)
3474
c = tl.dot(a, b)
3575
tl.store(c_block_ptr, c)
3676

37-
triton.compile(
38-
triton.compiler.ASTSource(
39-
fn=matmul_kernel, signature={
40-
'a_ptr': '*fp32', 'b_ptr': '*fp32', 'c_ptr': '*fp32', 'M': 'i32', 'N': 'i32', 'K': 'i32', 'stride_am':
41-
'i32', 'stride_ak': 'i32', 'stride_bk': 'i32', 'stride_bn': 'i32', 'stride_cm': 'i32', 'stride_cn':
42-
'i32'
43-
}, constants={}))
77+
with enable_remark_context():
78+
triton.compile(
79+
triton.compiler.ASTSource(
80+
fn=matmul_kernel,
81+
signature={
82+
"a_ptr": "*fp32",
83+
"b_ptr": "*fp32",
84+
"c_ptr": "*fp32",
85+
"M": "i32",
86+
"N": "i32",
87+
"K": "i32",
88+
"stride_am": "i32",
89+
"stride_ak": "i32",
90+
"stride_bk": "i32",
91+
"stride_bn": "i32",
92+
"stride_cm": "i32",
93+
"stride_cn": "i32",
94+
},
95+
constants={},
96+
))
4497
captured = capfd.readouterr()
4598

46-
assert "remark: Warning: can't use MMA V3 for the dot op" in captured.err, "expect MMA V3 remark"
99+
assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
47100
assert "note: see current operation:" in captured.err
48-
os.environ['MLIR_ENABLE_REMARK'] = '0'
49101

50102

51-
def test_remark_vectorization(capfd):
52-
os.environ["MLIR_ENABLE_REMARK"] = "1"
103+
def test_remark_vectorization(capfd, fresh_triton_cache):
53104

54105
@triton.jit
55106
def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr):
@@ -75,12 +126,52 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
75126
tl.store(out_ptr0 + (x4), tmp22, None)
76127

77128
XBLOCK = 1024
78-
triton.compile(
79-
triton.compiler.ASTSource(
80-
fn=ldst_vec, signature={
81-
'in_ptr0': '*i64', 'in_ptr1': '*i64', 'in_ptr2': '*fp16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp16'
82-
}, constants={"XBLOCK": XBLOCK}), options={"num_warps": 1})
129+
with enable_remark_context():
130+
triton.compile(
131+
triton.compiler.ASTSource(
132+
fn=ldst_vec,
133+
signature={
134+
"in_ptr0": "*i64",
135+
"in_ptr1": "*i64",
136+
"in_ptr2": "*fp16",
137+
"in_ptr3": "*fp32",
138+
"out_ptr0": "*fp16",
139+
},
140+
constants={"XBLOCK": XBLOCK},
141+
),
142+
options={"num_warps": 1},
143+
)
83144

84145
_, err = capfd.readouterr()
85146
assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
86-
os.environ["MLIR_ENABLE_REMARK"] = "0"
147+
148+
149+
def test_remark_swp_op_before_operands(capfd, fresh_triton_cache):
150+
151+
@triton.jit
152+
def kernel_pipe_error(in_ptr, out_ptr):
153+
SIZE: tl.constexpr = 64
154+
in_ptrs = in_ptr + tl.arange(0, SIZE)
155+
val = tl.zeros((SIZE, ), dtype=tl.float32)
156+
k = 0
157+
for i in tl.range(0, 64, num_stages=3):
158+
in_ptrs = in_ptr + tl.arange(0, SIZE) + SIZE * k
159+
val = tl.load(in_ptrs)
160+
out_ptrs = out_ptr + (tl.arange(0, SIZE) + i * SIZE)
161+
tl.store(out_ptrs, val)
162+
if tl.max(val) > 0:
163+
k += 1
164+
165+
with enable_remark_context():
166+
triton.compile(
167+
triton.compiler.ASTSource(
168+
fn=kernel_pipe_error,
169+
signature={"in_ptr": "*fp32", "out_ptr": "*fp32"},
170+
constants={},
171+
),
172+
options={"cluster_dims": (1, 1, 1)},
173+
)
174+
175+
_, err = capfd.readouterr()
176+
177+
assert "operation scheduled before its operands" in err, "expect swp op remark"

0 commit comments

Comments
 (0)