@@ -13,7 +13,7 @@
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
-from triton.compiler.errors import CompilationError
+from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
 
@@ -604,10 +604,10 @@ def kernel():
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
-    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
+    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
     tt.return
   }
-  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
+  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
     tt.return
   }
 }
@@ -855,7 +855,7 @@ def test_tensor_permute():
 def test_split_join():
     # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
     # CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], [1], [1], [0])
     a = ttgl.full([128], 1, ttgl.int32, layout)
     b = ttgl.full([128], 2, ttgl.int32, layout)
     # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
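
Note: the BlockedLayout change above passes the CTA-related arguments explicitly as [1], [1], [0] instead of leaving them at their defaults. A keyword-argument sketch of the same constructor call, assuming the parameter names mirror the IR attribute names shown in the CHECK lines; this is illustrative, not part of the patch:

    layout: ttgl.constexpr = ttgl.BlockedLayout(
        size_per_thread=[2],
        threads_per_warp=[32],
        warps_per_cta=[4],
        order=[0],
        ctas_per_cga=[1],    # one CTA per CGA, matching the positional [1]
        cta_split_num=[1],
        cta_order=[0],
    )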
@@ -883,6 +883,16 @@ def test_tensor_reshape():
     ttgl.static_assert(v.type.layout == expect_layout)
 
 
+@gluon.jit
+def static_assert_kernel():
+    ttgl.static_assert(False)
+
+
+def test_static_assert():
+    with pytest.raises(CompileTimeAssertionFailure):
+        run_parser(static_assert_kernel)
+
+
 @filecheck_test
 @gluon.jit
 def test_zeros():
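
Note: ttgl.static_assert is evaluated while the kernel source is parsed, so run_parser surfaces a failing condition as a Python CompileTimeAssertionFailure without running anything on a GPU. A self-contained sketch of the pattern the new test exercises, assuming the gluon import paths used by this test module (the kernel and test names here are illustrative, not from the diff):

    import pytest

    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton._filecheck import run_parser
    from triton.compiler.errors import CompileTimeAssertionFailure


    @gluon.jit
    def always_failing_kernel():
        # Compile-time check: a False condition aborts compilation
        # before any IR is emitted.
        ttgl.static_assert(False)


    def test_always_failing_kernel():
        with pytest.raises(CompileTimeAssertionFailure):
            run_parser(always_failing_kernel)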
|