Skip to content

Commit 98b957d

Browse files
authored
[FRONTEND] [BC Breaking] Require global variables to be insantiated as constexpr ob… (#5961)
1 parent 20874dd commit 98b957d

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,11 @@ def kernel():
300300
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa
301301

302302
# No error.
303-
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
303+
try:
304+
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
305+
assert False, "Using a constexpr annotated global variable should not be allowed"
306+
except CompilationError as e:
307+
assert "Cannot access global variable" in str(e)
304308

305309

306310
CONSTEXPR_GLOBAL = tl.constexpr(42)

python/test/unit/language/test_random.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def random_raw(self):
110110
# Unit Tests
111111
#####################################
112112

113-
BLOCK: tl.constexpr = 1024
113+
BLOCK = tl.constexpr(1024)
114114

115115
# test generation of random uint32
116116

@@ -144,7 +144,7 @@ def const_kernel(X, N, seed: tl.constexpr):
144144
# triton result
145145
x = torch.empty(size, dtype=torch_dtype, device=device)
146146
N = x.numel()
147-
grid = (triton.cdiv(N, BLOCK), )
147+
grid = (triton.cdiv(N, BLOCK.value), )
148148
if const_seed:
149149
const_kernel[grid](x, N, seed=seed)
150150
else:
@@ -184,7 +184,7 @@ def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr):
184184
# triton result
185185
x = torch.empty(size, dtype=torch.float32, device=device)
186186
N = x.numel()
187-
grid = (triton.cdiv(N, BLOCK), )
187+
grid = (triton.cdiv(N, BLOCK.value), )
188188
if const_seed:
189189
const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype))
190190
else:
@@ -238,7 +238,7 @@ def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr):
238238
# triton result
239239
x = torch.empty(size, dtype=torch.float32, device=device)
240240
N = x.numel()
241-
grid = (triton.cdiv(N, BLOCK), )
241+
grid = (triton.cdiv(N, BLOCK.value), )
242242
if const_seed:
243243
const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype))
244244
else:

python/test/unit/runtime/test_cache.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
223223
assert len(kernel.device_caches[device][0]) == 1
224224

225225

226-
GLOBAL_VAR: tl.constexpr = 1
226+
GLOBAL_VAR = tl.constexpr(1)
227227

228228

229229
def test_kernel_global_var_change(device):
@@ -263,7 +263,7 @@ def kernel():
263263
kernel[(1, )]()
264264

265265

266-
CONSTEXPR_GLOBAL: tl.constexpr = 42
266+
CONSTEXPR_GLOBAL = tl.constexpr(42)
267267

268268

269269
def test_local_does_not_shadow_global():
@@ -274,9 +274,9 @@ def kernel():
274274
a = CONSTEXPR_GLOBAL # noqa
275275
_, CONSTEXPR_GLOBAL = 0, 0 # noqa
276276

277-
CONSTEXPR_GLOBAL = 42
277+
CONSTEXPR_GLOBAL = tl.constexpr(42)
278278
kernel[(1, )]()
279-
CONSTEXPR_GLOBAL = 43
279+
CONSTEXPR_GLOBAL = tl.constexpr(43)
280280

281281
# Error because the `CONSTEXPR_GLOBAL` we're modifying is the same
282282
# `CONSTEXPR_GLOBAL` that's read inside `kernel`. (Alternatively, we could
@@ -288,7 +288,7 @@ def kernel():
288288
kernel[(1, )]()
289289

290290

291-
CONFLICTING_GLOBAL: tl.constexpr = 0
291+
CONFLICTING_GLOBAL = tl.constexpr(0)
292292

293293

294294
@triton.jit

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .._C.libtriton import ir
1313
from ..language import constexpr, semantic, str_to_ty, tensor
1414
from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, base_value, base_type
15-
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
15+
from ..runtime.jit import get_jit_fn_file_line
1616
# ideally we wouldn't need any runtime component
1717
from ..runtime import JITFunction
1818
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
@@ -347,9 +347,6 @@ def _is_constexpr_global(self, name):
347347
if _is_constexpr(val):
348348
return True
349349

350-
if a := self.gscope.get("__annotations__", {}).get(name):
351-
return _normalize_ty(a) == "constexpr"
352-
353350
return False
354351

355352
def _is_namedtuple(self, val):
@@ -386,8 +383,8 @@ def global_lookup(name: str, absent):
386383
textwrap.dedent(f"""\
387384
Cannot access global variable {name} from within @jit'ed
388385
function. Triton kernels can only access global variables that
389-
are annotated as constexpr (`x: triton.language.constexpr = 42`
390-
or `x = triton.language.constexpr(42)`). Alternatively, set the
386+
are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
387+
annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
391388
envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
392389
promise to support this forever.""").replace("\n", " "))
393390

0 commit comments

Comments
 (0)