Skip to content

Commit 492ea92

Browse files
Revert "[FRONTEND] added support for tuples (#5220)"
This reverts commit 9743ec0.
1 parent 2c10050 commit 492ea92

File tree

21 files changed

+288
-635
lines changed

21 files changed

+288
-635
lines changed

python/src/ir.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,6 @@ void init_triton_ir(py::module &&m) {
605605
"Function argument index out of range");
606606
return self.getArgument(idx);
607607
})
608-
.def("get_num_args", &FuncOp::getNumArguments)
609608
.def(
610609
"add_entry_block",
611610
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },

python/test/unit/language/test_compile_errors.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def kernel():
1717
a += 1 # noqa
1818

1919
with pytest.raises(CompilationError) as e:
20-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
20+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
2121

2222
try:
2323
assert "is not defined" in str(e.value), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
3232
0 + "a"
3333

3434
with pytest.raises(CompilationError) as e:
35-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
35+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
3636

3737
try:
3838
assert "at 2:4:" in str(e.value), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
4747
tl.static_assert(isinstance(0, tl.tensor))
4848

4949
with pytest.raises(CompilationError) as e:
50-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
50+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
5151

5252
try:
5353
assert isinstance(e.value, CompileTimeAssertionFailure)
@@ -66,7 +66,7 @@ def kernel():
6666
not (0, 0)
6767

6868
with pytest.raises(CompilationError) as e:
69-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
69+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
7070

7171
try:
7272
assert e.value.__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
8383
1.0 << 1
8484

8585
with pytest.raises(CompilationError) as e:
86-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
86+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
8787

8888
try:
8989
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
107107
nested_call()
108108

109109
with pytest.raises(CompilationError) as e:
110-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
110+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
111111

112112
try:
113113
inner = e.value.__cause__
@@ -130,7 +130,7 @@ def kernel():
130130
tl.expand_dims(None, -1)
131131

132132
with pytest.raises(CompilationError) as e:
133-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
133+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
134134

135135
try:
136136
inner = e.value.__cause__
@@ -157,7 +157,7 @@ def kernel():
157157
a = two_returns()
158158
a + tl.arange(0, 4) # only works if we took the first return
159159

160-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
160+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
161161

162162

163163
def test_not_const_annotate_no_err():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
166166
def kernel(N: int = 1):
167167
pass
168168

169-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
169+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
170170

171171

172172
@triton.jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
186186
a = returns_branched_on_constexpr(N)
187187
a + tl.arange(0, 4)
188188

189-
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))
189+
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
190190

191191
@triton.jit
192192
def kernel2(N: tl.constexpr):
193193
a = returns_branched_on_constexpr(N)
194194
a + tl.arange(0, 8)
195195

196-
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))
196+
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
197197

198198

199199
@triton.jit
@@ -211,7 +211,7 @@ def kernel(N: int):
211211
returns_branched_on_non_constexpr(N)
212212

213213
with pytest.raises(CompilationError) as e:
214-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
214+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
215215

216216
try:
217217
assert "at 2:4:" in str(e.value), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
227227
tl.arange(2, 7)
228228

229229
with pytest.raises(CompilationError) as e:
230-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
230+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
231231
assert str(e.value.__cause__) == "arange's range must be a power of 2"
232232

233233

@@ -238,7 +238,7 @@ def kernel():
238238
tl.full((33, ), 0, dtype=tl.int64)
239239

240240
with pytest.raises(CompilationError) as e:
241-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
241+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
242242
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"
243243

244244

@@ -251,7 +251,7 @@ def kernel():
251251
a = CAPTURED # noqa
252252

253253
with pytest.raises(CompilationError) as e:
254-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
254+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
255255
assert "CAPTURED is not defined" in str(e.value)
256256

257257

@@ -265,7 +265,7 @@ def kernel():
265265
a = GLOBAL # noqa
266266

267267
with pytest.raises(CompilationError) as e:
268-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
268+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
269269
assert "global variable" in str(e.value)
270270

271271

@@ -279,7 +279,7 @@ def kernel():
279279
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa
280280

281281
# No error.
282-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
282+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
283283

284284

285285
CONSTEXPR_GLOBAL = tl.constexpr(42)
@@ -292,7 +292,7 @@ def kernel():
292292
a = CONSTEXPR_GLOBAL # noqa
293293

294294
# No error.
295-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
295+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
296296

297297

298298
TYPE_ALIAS = tl.pointer_type(tl.int32)
@@ -305,7 +305,7 @@ def kernel():
305305
a = TYPE_ALIAS # noqa
306306

307307
# No error.
308-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
308+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
309309

310310

311311
def test_global_access_in_fn_default_arg():
@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
315315
pass
316316

317317
# No error.
318-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))
318+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
319319

320320

321321
def test_defaults_assign_no_err():
@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
324324
def kernel(a=1, B: tl.constexpr = ""):
325325
pass
326326

327-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))
327+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
328328

329329

330330
def test_where_warning(fresh_triton_cache):
@@ -337,7 +337,7 @@ def kernel():
337337
tl.where(a, b, c)
338338

339339
with pytest.warns(UserWarning):
340-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
340+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
341341

342342

343343
@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
@@ -371,8 +371,7 @@ def dtype_kernel(dtype: tl.constexpr):
371371
ctx = pytest.raises(CompilationError, match="")
372372

373373
with ctx as e:
374-
triton.compile(
375-
triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))
374+
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
376375

377376
if dtype not in supported_dtypes:
378377
try:
@@ -391,7 +390,7 @@ def dot_kernel():
391390
tl.dot(a, b, max_num_imprecise_acc=128)
392391

393392
with pytest.raises(CompilationError) as e:
394-
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
393+
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
395394
try:
396395
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
397396
except AssertionError as assertion_err:

python/test/unit/language/test_core.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4407,17 +4407,15 @@ def kernel(x):
44074407
def test_value_specialization(value: int, value_type: str, device) -> None:
44084408

44094409
def repr(specialization):
4410-
ty = specialization.signature["value1"]
4411-
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
4412-
return f"kernel_{ty}_{cst}"
4410+
spec_type = specialization.signature["VALUE"]
4411+
return f"kernel_{spec_type}"
44134412

44144413
@triton.jit(repr=repr)
4415-
def kernel(value1, is_one, X):
4414+
def kernel(VALUE, X):
44164415
pass
44174416

44184417
x = torch.tensor([3.14159], device=device)
4419-
h = kernel[(1, )](value, 1, x)
4420-
assert "is_one" in h.name
4418+
h = kernel[(1, )](value, x)
44214419
assert value_type in h.name
44224420

44234421

@@ -6188,19 +6186,6 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
61886186
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))
61896187

61906188

6191-
def test_dtype(device):
6192-
6193-
@triton.jit
6194-
def kernel(X):
6195-
dtype_x: tl.constexpr = X.dtype.element_ty
6196-
tl.static_assert(dtype_x == tl.int32)
6197-
tl.static_assert(dtype_x == tl.constexpr(tl.int32))
6198-
tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))
6199-
6200-
X = torch.zeros(1, dtype=torch.int32, device=device)
6201-
kernel[(1, )](X)
6202-
6203-
62046189
def test_side_effectful_scan(device):
62056190
if device != "cuda":
62066191
pytest.xfail()

python/test/unit/language/test_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def kernel():
2323
pass
2424

2525
try:
26-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
26+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
2727
except Exception as e:
2828
pytest.fail(f"triton compile failed with error: {e}")
2929

python/test/unit/language/test_tuple.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

python/test/unit/runtime/test_bindings.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,15 @@ def walk_fn(op):
6363
backend = triton.compiler.compiler.make_backend(target)
6464
src = triton.compiler.compiler.ASTSource(
6565
fn=kernel,
66-
signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
67-
for i, arg in enumerate(args)},
68-
constexprs={kernel.arg_names[i]: arg
69-
for i, arg in enumerate(args)
70-
if not isinstance(arg, torch.Tensor)},
71-
attrs=backend.get_attrs_descriptor(kernel.params, args),
66+
signature={
67+
kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
68+
for i, arg in enumerate(args)
69+
if i not in kernel.constexprs
70+
},
71+
constants={kernel.arg_names[i]: arg
72+
for i, arg in enumerate(args)
73+
if not isinstance(arg, torch.Tensor)},
74+
attrs=backend.get_attrs_descriptor(args, kernel.params),
7275
)
7376

7477
context = triton._C.libtriton.ir.context()

0 commit comments

Comments
 (0)