@@ -307,25 +307,25 @@ def anchor(x):
 @filecheck_test
 @gluon.jit
 def test_warp_specialize():
-    # CHECK-LABEL: tt.func public @test_warp_specialize
+    # CHECK-LABEL: test_warp_specialize
     # CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
     # CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
     # CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
     # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
     # CHECK-NEXT: default {
-    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @"warp_specialize_default{{.*}}"([[A]], [[B]], [[C]])
+    # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}([[A]], [[B]], [[C]])
     # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
     # CHECK-NEXT: }
     # CHECK-NEXT: partition0(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
-    # CHECK-NEXT: call @"warp_specialize_worker0{{.*}}"(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: partition1(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>, %arg2: tensor<4xi32>) num_warps(4) {
-    # CHECK-NEXT: call @"warp_specialize_worker1{{.*}}"(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
-    # CHECK-NEXT: call @anchor{{.*}}([[OUTS]]#0)
-    # CHECK-NEXT: call @"anchor{{.*}}"([[OUTS]]#1, [[OUTS]]#2)
+    # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
+    # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#1, [[OUTS]]#2)
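     # Two worker partitions with 4 warps each, requesting 24 and 48 registers
     # respectively (cf. num_warps(4) and requestedRegisters in the CHECKs above).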
     pair = Pair(tl.arange(0, 1), tl.arange(0, 2))
     a, b = ttgl.warp_specialize((pair, tl.arange(0, 4)), warp_specialize_default,
                                 [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
@@ -541,6 +541,29 @@ def kernel():
     assert "order must be a permutation of 0..(rank-1), but was [1]" in str(e.value.__cause__)
 
 
+@gluon.jit
+def tmem_subslice_kernel():
+    layout: ttgl.constexpr = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
+    tmem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, [2, 256, 256], layout)
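+    # subslice(0) takes a constexpr index into the leading dimension of the
+    # [2, 256, 256] allocation; it lowers to ttg.memdesc_subview (see IR below).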
+    tmem.subslice(0)
+
+
+def test_tmem_subslice_constexpr():
+    expecttest.assert_expected_inline(
+        run_parser(tmem_subslice_kernel).str_nodebug(), """\
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module {
+  tt.func public @tmem_subslice_kernel() attributes {noinline = false} {
+    %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c0_i32_0 = arith.constant 0 : i32
+    %0 = ttg.memdesc_subview %result[%c0_i32, %c0_i32_0, %c0_i32_0] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256>
+    tt.return
+  }
+}
+""")
+
+
 @gluon.jit
 def smem_and_layout_user(smem, a: ttgl.constexpr):
     pass
@@ -561,10 +584,10 @@ def kernel():
 module {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
-    tt.call @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
+    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
     tt.return
   }
-  tt.func private @"smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
+  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
     tt.return
   }
 }