@@ -17,7 +17,7 @@ def kernel():
1717 a += 1 # noqa
1818
1919 with pytest .raises (CompilationError ) as e :
20- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
20+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
2121
2222 try :
2323 assert "is not defined" in str (e .value ), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
3232 0 + "a"
3333
3434 with pytest .raises (CompilationError ) as e :
35- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
35+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
3636
3737 try :
3838 assert "at 2:4:" in str (e .value ), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
4747 tl .static_assert (isinstance (0 , tl .tensor ))
4848
4949 with pytest .raises (CompilationError ) as e :
50- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
50+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
5151
5252 try :
5353 assert isinstance (e .value , CompileTimeAssertionFailure )
@@ -66,7 +66,7 @@ def kernel():
6666 not (0 , 0 )
6767
6868 with pytest .raises (CompilationError ) as e :
69- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
69+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
7070
7171 try :
7272 assert e .value .__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
8383 1.0 << 1
8484
8585 with pytest .raises (CompilationError ) as e :
86- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
86+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
8787
8888 try :
8989 assert "at 2:4:" in str (e .value ), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
107107 nested_call ()
108108
109109 with pytest .raises (CompilationError ) as e :
110- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
110+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
111111
112112 try :
113113 inner = e .value .__cause__
@@ -130,7 +130,7 @@ def kernel():
130130 tl .expand_dims (None , - 1 )
131131
132132 with pytest .raises (CompilationError ) as e :
133- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
133+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
134134
135135 try :
136136 inner = e .value .__cause__
@@ -157,7 +157,7 @@ def kernel():
157157 a = two_returns ()
158158 a + tl .arange (0 , 4 ) # only works if we took the first return
159159
160- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
160+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
161161
162162
163163def test_not_const_annotate_no_err ():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
166166 def kernel (N : int = 1 ):
167167 pass
168168
169- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'N' : 'i32' }, constants = {}))
169+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'N' : 'i32' }, constexprs = {}))
170170
171171
172172@triton .jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
186186 a = returns_branched_on_constexpr (N )
187187 a + tl .arange (0 , 4 )
188188
189- triton .compile (triton .compiler .ASTSource (fn = kernel1 , signature = {}, constants = {"N" : 0 }))
189+ triton .compile (triton .compiler .ASTSource (fn = kernel1 , signature = {"N" : "constexpr" }, constexprs = {"N" : 0 }))
190190
191191 @triton .jit
192192 def kernel2 (N : tl .constexpr ):
193193 a = returns_branched_on_constexpr (N )
194194 a + tl .arange (0 , 8 )
195195
196- triton .compile (triton .compiler .ASTSource (fn = kernel2 , signature = {}, constants = {"N" : 1 }))
196+ triton .compile (triton .compiler .ASTSource (fn = kernel2 , signature = {"N" : "constexpr" }, constexprs = {"N" : 1 }))
197197
198198
199199@triton .jit
@@ -211,7 +211,7 @@ def kernel(N: int):
211211 returns_branched_on_non_constexpr (N )
212212
213213 with pytest .raises (CompilationError ) as e :
214- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'N' : 'i32' }, constants = {}))
214+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'N' : 'i32' }, constexprs = {}))
215215
216216 try :
217217 assert "at 2:4:" in str (e .value ), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
227227 tl .arange (2 , 7 )
228228
229229 with pytest .raises (CompilationError ) as e :
230- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
230+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
231231 assert str (e .value .__cause__ ) == "arange's range must be a power of 2"
232232
233233
@@ -238,7 +238,7 @@ def kernel():
238238 tl .full ((33 , ), 0 , dtype = tl .int64 )
239239
240240 with pytest .raises (CompilationError ) as e :
241- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
241+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
242242 assert str (e .value .__cause__ ) == "Shape element 0 must be a power of 2"
243243
244244
@@ -251,7 +251,7 @@ def kernel():
251251 a = CAPTURED # noqa
252252
253253 with pytest .raises (CompilationError ) as e :
254- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
254+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
255255 assert "CAPTURED is not defined" in str (e .value )
256256
257257
@@ -265,7 +265,7 @@ def kernel():
265265 a = GLOBAL # noqa
266266
267267 with pytest .raises (CompilationError ) as e :
268- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
268+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
269269 assert "global variable" in str (e .value )
270270
271271
@@ -279,7 +279,7 @@ def kernel():
279279 a = CONSTEXPR_ANNOTATED_GLOBAL # noqa
280280
281281 # No error.
282- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
282+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
283283
284284
285285CONSTEXPR_GLOBAL = tl .constexpr (42 )
@@ -292,7 +292,7 @@ def kernel():
292292 a = CONSTEXPR_GLOBAL # noqa
293293
294294 # No error.
295- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
295+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
296296
297297
298298TYPE_ALIAS = tl .pointer_type (tl .int32 )
@@ -305,7 +305,7 @@ def kernel():
305305 a = TYPE_ALIAS # noqa
306306
307307 # No error.
308- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
308+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
309309
310310
311311def test_global_access_in_fn_default_arg ():
@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
315315 pass
316316
317317 # No error.
318- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'a' : "i32" }, constants = {}))
318+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'a' : "i32" }, constexprs = {}))
319319
320320
321321def test_defaults_assign_no_err ():
@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
324324 def kernel (a = 1 , B : tl .constexpr = "" ):
325325 pass
326326
327- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'a' : 'i32' }, constants = {'B' : "" }))
327+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {'a' : 'i32' , 'B' : 'constexpr' }, constexprs = {'B' : "" }))
328328
329329
330330def test_where_warning (fresh_triton_cache ):
@@ -337,7 +337,7 @@ def kernel():
337337 tl .where (a , b , c )
338338
339339 with pytest .warns (UserWarning ):
340- triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constants = {}))
340+ triton .compile (triton .compiler .ASTSource (fn = kernel , signature = {}, constexprs = {}))
341341
342342
343343@pytest .mark .parametrize ("dtype" , [tl .float8e5 , tl .float8e5b16 , tl .float8e4nv , tl .float8e4b8 , tl .float8e4b15 ])
@@ -371,7 +371,8 @@ def dtype_kernel(dtype: tl.constexpr):
371371 ctx = pytest .raises (CompilationError , match = "" )
372372
373373 with ctx as e :
374- triton .compile (triton .compiler .ASTSource (fn = dtype_kernel , signature = {}, constants = {"dtype" : dtype }))
374+ triton .compile (
375+ triton .compiler .ASTSource (fn = dtype_kernel , signature = {"dtype" : "constexpr" }, constexprs = {"dtype" : dtype }))
375376
376377 if dtype not in supported_dtypes :
377378 try :
@@ -426,7 +427,7 @@ def dot_kernel():
426427 tl .dot (a , b , max_num_imprecise_acc = 128 )
427428
428429 with pytest .raises (CompilationError ) as e :
429- triton .compile (triton .compiler .ASTSource (fn = dot_kernel , signature = {}, constants = {}))
430+ triton .compile (triton .compiler .ASTSource (fn = dot_kernel , signature = {}, constexprs = {}))
430431 try :
431432 assert (str (e .value .__cause__ ) == "max_num_imprecise_acc (128) must be <= K (64)" )
432433 except AssertionError as assertion_err :
0 commit comments