
Commit 2347dba

anmyachev and pbchekin authored
Reapply "[FRONTEND] added support for tuples (#5220)" (#3043)
This reverts commit 492ea92. Summary of changes:

* Support for tuples.
* `constexprs` are now also part of the signature. The format is: `signature={..., 'o': "*fp32", '[Name of constexpr]': 'constexpr'}`.
* For `ASTSource` there is a parameter name change: `constants` -> `constexprs`.
* A new nesting level has appeared for the main properties (defined in `_add_common_properties`): `divisibility_16`, `equal_to_1` and `equal_to_none`.

---------

Signed-off-by: Anatoly Myachev <[email protected]>
Co-authored-by: Pavel Chekin <[email protected]>
1 parent 6abfacb commit 2347dba
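For reference, the API migration described in the summary looks roughly like the sketch below. This is a minimal illustration, not code from this commit: the kernel name `add_one`, its body, and its argument names are hypothetical; only the shape of the `ASTSource` call follows the changes above (constexpr arguments listed in `signature` as `'constexpr'`, and `constants` renamed to `constexprs`).

```python
import triton
import triton.language as tl


@triton.jit
def add_one(o, BLOCK: tl.constexpr):  # hypothetical kernel, for illustration only
    offs = tl.arange(0, BLOCK)
    tl.store(o + offs, tl.load(o + offs) + 1)


# Before this commit: constexpr arguments were omitted from the signature and
# passed through the `constants` parameter of ASTSource.
# src = triton.compiler.ASTSource(fn=add_one, signature={'o': '*fp32'}, constants={'BLOCK': 16})

# After this commit: constexpr arguments also appear in the signature (as the
# string 'constexpr') and the parameter is renamed to `constexprs`.
src = triton.compiler.ASTSource(
    fn=add_one,
    signature={'o': '*fp32', 'BLOCK': 'constexpr'},
    constexprs={'BLOCK': 16},
)
triton.compile(src)
```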

File tree

25 files changed: +929 −302 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ repos:
   - id: check-symlinks
   - id: destroyed-symlinks
   - id: trailing-whitespace
+    exclude: .*.patch
   - id: end-of-file-fixer
   - id: check-yaml
   - id: check-toml

python/src/ir.cc

Lines changed: 1 addition & 0 deletions

@@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) {
                    "Function argument index out of range");
             return self.getArgument(idx);
           })
+      .def("get_num_args", &FuncOp::getNumArguments)
       .def(
           "add_entry_block",
           [](FuncOp &self) -> Block * { return self.addEntryBlock(); },

python/test/unit/language/test_compile_errors.py

Lines changed: 25 additions & 24 deletions

@@ -17,7 +17,7 @@ def kernel():
         a += 1 # noqa

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         assert "is not defined" in str(e.value), "error should mention the undefined variable"

@@ -32,7 +32,7 @@ def kernel():
         0 + "a"

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         assert "at 2:4:" in str(e.value), "error should point to the 0"

@@ -47,7 +47,7 @@ def kernel():
         tl.static_assert(isinstance(0, tl.tensor))

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         assert isinstance(e.value, CompileTimeAssertionFailure)

@@ -66,7 +66,7 @@ def kernel():
         not (0, 0)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         assert e.value.__cause__ is None

@@ -83,7 +83,7 @@ def kernel():
         1.0 << 1

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         assert "at 2:4:" in str(e.value), "error should point to the 1.0"

@@ -107,7 +107,7 @@ def kernel():
         nested_call()

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         inner = e.value.__cause__

@@ -130,7 +130,7 @@ def kernel():
         tl.expand_dims(None, -1)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

     try:
         inner = e.value.__cause__

@@ -157,7 +157,7 @@ def kernel():
         a = two_returns()
         a + tl.arange(0, 4) # only works if we took the first return

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


 def test_not_const_annotate_no_err():

@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
     def kernel(N: int = 1):
         pass

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))


 @triton.jit

@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
         a = returns_branched_on_constexpr(N)
         a + tl.arange(0, 4)

-    triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))

     @triton.jit
     def kernel2(N: tl.constexpr):
         a = returns_branched_on_constexpr(N)
         a + tl.arange(0, 8)

-    triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))


 @triton.jit

@@ -211,7 +211,7 @@ def kernel(N: int):
         returns_branched_on_non_constexpr(N)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))

     try:
         assert "at 2:4:" in str(e.value), "error should point to the function call"

@@ -227,7 +227,7 @@ def kernel():
         tl.arange(2, 7)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
     assert str(e.value.__cause__) == "arange's range must be a power of 2"


@@ -238,7 +238,7 @@ def kernel():
         tl.full((33, ), 0, dtype=tl.int64)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
     assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"


@@ -251,7 +251,7 @@ def kernel():
         a = CAPTURED # noqa

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
     assert "CAPTURED is not defined" in str(e.value)


@@ -265,7 +265,7 @@ def kernel():
         a = GLOBAL # noqa

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
     assert "global variable" in str(e.value)


@@ -279,7 +279,7 @@ def kernel():
         a = CONSTEXPR_ANNOTATED_GLOBAL # noqa

     # No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


 CONSTEXPR_GLOBAL = tl.constexpr(42)

@@ -292,7 +292,7 @@ def kernel():
         a = CONSTEXPR_GLOBAL # noqa

     # No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


 TYPE_ALIAS = tl.pointer_type(tl.int32)

@@ -305,7 +305,7 @@ def kernel():
         a = TYPE_ALIAS # noqa

     # No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


 def test_global_access_in_fn_default_arg():

@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
         pass

     # No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))


 def test_defaults_assign_no_err():

@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
     def kernel(a=1, B: tl.constexpr = ""):
         pass

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))


 def test_where_warning(fresh_triton_cache):

@@ -337,7 +337,7 @@ def kernel():
         tl.where(a, b, c)

     with pytest.warns(UserWarning):
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


 @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])

@@ -371,7 +371,8 @@ def dtype_kernel(dtype: tl.constexpr):
     ctx = pytest.raises(CompilationError, match="")

     with ctx as e:
-        triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
+        triton.compile(
+            triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))

     if dtype not in supported_dtypes:
         try:

@@ -426,7 +427,7 @@ def dot_kernel():
         tl.dot(a, b, max_num_imprecise_acc=128)

     with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
     try:
         assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
     except AssertionError as assertion_err:

python/test/unit/language/test_core.py

Lines changed: 19 additions & 4 deletions

@@ -4494,15 +4494,17 @@ def kernel(x):
 def test_value_specialization(value: int, value_type: str, device) -> None:

     def repr(specialization):
-        spec_type = specialization.signature["VALUE"]
-        return f"kernel_{spec_type}"
+        ty = specialization.signature["value1"]
+        cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
+        return f"kernel_{ty}_{cst}"

     @triton.jit(repr=repr)
-    def kernel(VALUE, X):
+    def kernel(value1, is_one, X):
         pass

     x = torch.tensor([3.14159], device=device)
-    h = kernel[(1, )](value, x)
+    h = kernel[(1, )](value, 1, x)
+    assert "is_one" in h.name
     assert value_type in h.name


@@ -6346,6 +6348,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
     torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


+def test_dtype(device):
+
+    @triton.jit
+    def kernel(X):
+        dtype_x: tl.constexpr = X.dtype.element_ty
+        tl.static_assert(dtype_x == tl.int32)
+        tl.static_assert(dtype_x == tl.constexpr(tl.int32))
+        tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))
+
+    X = torch.zeros(1, dtype=torch.int32, device=device)
+    kernel[(1, )](X)
+
+
 def test_side_effectful_scan(device):
     if device != "cuda":
         pytest.xfail()

python/test/unit/language/test_decorator.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ def kernel():
         pass

     try:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
     except Exception as e:
         pytest.fail(f"triton compile failed with error: {e}")
2929

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+import pytest
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def _tuple_increment(values):
+    for i in tl.static_range(len(values)):
+        values[i] = values[i] + 1
+    return values
+
+
+@triton.jit
+def _tuple_index_func(Ptrs, values):
+    for i in tl.static_range(len(values)):
+        tl.store(Ptrs[i], values[i])
+
+
+@triton.jit
+def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
+    values = _tuple_increment(values)
+    _tuple_index_func(Ptrs, values)
+
+
+@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
+def test_index(size, device="xpu"):
+    vals = tuple([i + 1 for i in range(size)])
+    rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
+    _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
+    assert vals == tuple([x.item() - 1 for x in rets])
+
+
+# ----
+
+
+@triton.jit
+def _tuple_assign(XPtrs, YPtrs, values):
+    # assign from tuple
+    X0, X1 = XPtrs
+    x0, x1 = values
+    tl.store(X0, x0)
+    tl.store(X1, x1)
+    # assign to tuple
+    Y0, Y1, Y2 = YPtrs
+    Y = Y0, Y1, Y2
+    y = x0, 10, x1
+    tl.store(Y[0], y[0])
+    tl.store(Y[1], y[1])
+    tl.store(Y[2], y[2])
+
+
+def test_assign(device="xpu"):
+    vals = (2., 3.)
+    x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
+    y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
+    _tuple_assign[(1, )](x, y, vals)
+    assert x[0] == vals[0]
+    assert x[1] == vals[1]
+    assert y[0] == vals[0]
+    assert y[1] == 10
+    assert y[2] == vals[1]
+
+
+# -------
+
+
+@triton.jit
+def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
+    tl.static_assert(tuple1[1] is None)
+    tl.store(Ptr + 5, cst2)
+    tl.store(Ptr + 6, tuple1[0])
+    tl.store(Ptr + 7, tl.load(tuple1[2][0]))
+    tl.store(Ptr + 8, tuple1[2][1][0])
+    tl.store(Ptr + 9, tl.load(tuple1[2][1][2]))
+
+
+# test serialization/deserialization of tuple arguments in
+# the frontend.
+@triton.jit
+def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
+    tl.static_assert(N1 is None)
+    tl.static_assert(tuple1[1][1] is None)
+    tl.store(Ptr + 0, tl.load(tuple1[0]))
+    tl.store(Ptr + 1, tuple1[1][0])
+    tl.store(Ptr + 2, tl.load(tuple1[1][2]))
+    tl.store(Ptr + 3, cst1 + val1)
+    tl.store(Ptr + 4, tl.load(tuple2[0]))
+    _tuple_fn0(Ptr, 15, (-1, None, tuple1))
+
+
+def test_serialize(device="xpu"):
+    x0 = torch.tensor([8], dtype=torch.int32, device=device)
+    x1 = torch.tensor([12], dtype=torch.int32, device=device)
+    y0 = torch.tensor([10], dtype=torch.int32, device=device)
+    z = torch.empty((10, ), dtype=torch.int32, device=device)
+    # we want to check that JIT specialization propagates to tuples:
+    _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
+    ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
+    assert torch.equal(z, ref)
