Skip to content

Commit 7ce287d

Browse files
Mogballpeterbell10
andauthored
[Frontend][Gluon] Make static_assert actually work (#7168)
* Added `ttgl.static_assert` (and `static_print`) * Fix binop where RHS is constexpr but LHS is not * Fix `unwrap_if_constexpr` to recursively unwrap --------- Co-authored-by: peterbell10 <[email protected]>
1 parent 219433e commit 7ce287d

File tree

6 files changed

+28
-7
lines changed

6 files changed

+28
-7
lines changed

python/test/gluon/test_frontend.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from triton._filecheck import filecheck_test, run_parser
1414
import triton.language as tl
1515
from triton._internal_testing import is_cuda
16-
from triton.compiler.errors import CompilationError
16+
from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
1717

1818
TARGET_PAT = re.compile('ttg.target = "[^"]*"')
1919

@@ -604,10 +604,10 @@ def kernel():
604604
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
605605
tt.func public @kernel() attributes {noinline = false} {
606606
%0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
607-
tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
607+
tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
608608
tt.return
609609
}
610-
tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
610+
tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
611611
tt.return
612612
}
613613
}
@@ -855,7 +855,7 @@ def test_tensor_permute():
855855
def test_split_join():
856856
# CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
857857
# CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
858-
layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
858+
layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], [1], [1], [0])
859859
a = ttgl.full([128], 1, ttgl.int32, layout)
860860
b = ttgl.full([128], 2, ttgl.int32, layout)
861861
# CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
@@ -883,6 +883,16 @@ def test_tensor_reshape():
883883
ttgl.static_assert(v.type.layout == expect_layout)
884884

885885

886+
@gluon.jit
887+
def static_assert_kernel():
888+
ttgl.static_assert(False)
889+
890+
891+
def test_static_assert():
892+
with pytest.raises(CompileTimeAssertionFailure):
893+
run_parser(static_assert_kernel)
894+
895+
886896
@filecheck_test
887897
@gluon.jit
888898
def test_zeros():

python/triton/compiler/code_generator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ def _apply_binary_method(self, method_name, lhs, rhs):
637637
if _is_triton_tensor(rhs):
638638
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
639639
return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
640+
if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
641+
lhs = constexpr(lhs)
640642
return getattr(lhs, method_name)(rhs)
641643

642644
def visit_BinOp(self, node):
@@ -1457,9 +1459,12 @@ def ret(self, node: ast.Call):
14571459

14581460
return ret
14591461

1462+
from ..experimental.gluon import language as ttgl
14601463
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
14611464
language.core.static_assert: execute_static_assert,
14621465
language.core.static_print: static_executor(print),
1466+
ttgl.static_assert: execute_static_assert,
1467+
ttgl.static_print: static_executor(print),
14631468
int: static_executor(int),
14641469
len: static_executor(len),
14651470
}

python/triton/compiler/compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
1111
from ..runtime.driver import driver
1212
from ..tools.disasm import get_sass
13-
# TODO: this shouldn't be here
14-
from .code_generator import ast_to_ttir
1513
from pathlib import Path
1614
import re
1715
import functools
@@ -81,6 +79,7 @@ def hash(self):
8179
return hashlib.sha256(key.encode("utf-8")).hexdigest()
8280

8381
def make_ir(self, options, codegen_fns, module_map, context):
82+
from .code_generator import ast_to_ttir
8483
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
8584
module_map=module_map)
8685

python/triton/experimental/gluon/_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22
import triton
3-
from triton.compiler.code_generator import ast_to_ttir
43
from triton.compiler.compiler import ASTSource
54
from triton.backends.compiler import Language
65
from triton.runtime.jit import JITFunction
@@ -19,6 +18,7 @@ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
1918

2019
def make_ir(self, options, codegen_fns, module_map, context):
2120
from triton.compiler.compiler import make_backend
21+
from triton.compiler.code_generator import ast_to_ttir
2222

2323
builder = ir.builder(context)
2424
module = builder.create_module()

python/triton/experimental/gluon/language/_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"reshape",
5555
"split",
5656
"static_assert",
57+
"static_print",
5758
"store",
5859
"to_tensor",
5960
"where",

python/triton/language/core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,12 @@ def wrapper(*args, _semantic=None, **kwargs):
364364

365365

366366
def _unwrap_if_constexpr(o):
367+
if isinstance(o, list):
368+
return [_unwrap_if_constexpr(x) for x in o]
369+
if isinstance(o, builtins.tuple):
370+
return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
371+
if isinstance(o, tuple):
372+
return tuple(_unwrap_if_constexpr(x) for x in o)
367373
return o.value if isinstance(o, constexpr) else o
368374

369375

0 commit comments

Comments
 (0)