@@ -246,8 +246,7 @@ def shared_memory_cast_kernel():
     layout_a: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=8,
                                                        rank=2)
     layout_T: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=True, element_bitwidth=8,
-                                                       rank=2, ctas_per_cga=[1, 1], cta_split_num=[1,
-                                                                                                   1], cta_order=[1, 0])
+                                                       rank=2)
     smem = ttgl.allocate_shared_memory(ttgl.int8, [2, 256, 128], layout_a)
     perm = smem.index(0).permute((1, 0))
     ttgl.static_assert(perm.type.layout == layout_T)
@@ -613,10 +612,10 @@ def kernel():
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
-    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
+    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
     tt.return
   }
-  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
+  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
     tt.return
   }
 }
@@ -855,7 +854,7 @@ def test_tensor_permute():
     a = ttgl.full([32, 16], 0, ttgl.int32, layout=layout)
     # CHECK: tt.trans{{.*}} : tensor<32x16xi32, [[BLOCKED]]> -> tensor<16x32xi32, [[BLOCKED1]]>
     res = ttgl.permute(a, [1, 0])
-    permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1], [1, 1], [1, 1], [1, 0])
+    permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1])
     ttgl.static_assert(permuted_layout == res.type.layout)


@@ -869,7 +868,7 @@ def test_split_join():
     b = ttgl.full([128], 2, ttgl.int32, layout)
     # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
     res = ttgl.join(a, b)
-    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0], [1, 1], [1, 1], [1, 0])
+    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0])
     ttgl.static_assert(res.type.layout == expect_layout)

     # CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, [[BLOCKED]]>
@@ -878,6 +877,17 @@ def test_split_join():
     ttgl.static_assert(d.type.layout == layout)


+@filecheck_test
+@gluon.jit
+def test_reshape_linear_layout():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+    # CHECK: [[LINEAR:#.*]] = #ttg.linear
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1])
+    x = ttgl.full([128, 1], 1, ttgl.int32, layout=layout)
+    # CHECK: tt.reshape %{{.*}} : tensor<128x1xi32, [[BLOCKED]]> -> tensor<128xi32, [[LINEAR]]>
+    x.reshape([128])
+
+
 @filecheck_test
 @gluon.jit
 def test_tensor_reshape():
@@ -887,8 +897,7 @@ def test_tensor_reshape():
     a = ttgl.full([256], 1, ttgl.int32, layout)
     # CHECK: tt.reshape {{.*}} : tensor<256xi32, [[BLOCKED]]> -> tensor<8x4x8xi32, [[BLOCKED1]]>
     v = a.reshape([8, 4, 8])
-    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1],
-                                                       [2, 1, 0])
+    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0])
     ttgl.static_assert(v.type.layout == expect_layout)


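For orientation, here is a minimal standalone sketch of the four-argument BlockedLayout construction that the updated tests exercise. It is illustrative only and not part of the diff: the import paths follow the gluon test conventions, and the kernel name, shapes, and values are assumptions.

```python
# Minimal sketch, assuming the ttgl API used in the tests above; the kernel
# name and shapes here are illustrative, not taken from the change itself.
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def blocked_layout_example():
    # BlockedLayout built from size_per_thread, threads_per_warp,
    # warps_per_cta, and order; no CTA-related arguments are passed.
    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [32, 1], [4, 1], [1, 0])
    x = ttgl.full([128, 64], 0, ttgl.int32, layout=layout)
    ttgl.static_assert(x.type.layout == layout)
```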