@@ -246,8 +246,7 @@ def shared_memory_cast_kernel():
246
246
layout_a : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = False , element_bitwidth = 8 ,
247
247
rank = 2 )
248
248
layout_T : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = True , element_bitwidth = 8 ,
249
- rank = 2 , ctas_per_cga = [1 , 1 ], cta_split_num = [1 ,
250
- 1 ], cta_order = [1 , 0 ])
249
+ rank = 2 )
251
250
smem = ttgl .allocate_shared_memory (ttgl .int8 , [2 , 256 , 128 ], layout_a )
252
251
perm = smem .index (0 ).permute ((1 , 0 ))
253
252
ttgl .static_assert (perm .type .layout == layout_T )
@@ -613,10 +612,10 @@ def kernel():
613
612
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
614
613
tt.func public @kernel() attributes {noinline = false} {
615
614
%0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
616
- tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS [32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None )_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
615
+ tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS [32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_ )_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
617
616
tt.return
618
617
}
619
- tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS [32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None )_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
618
+ tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS [32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_ )_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
620
619
tt.return
621
620
}
622
621
}
@@ -855,7 +854,7 @@ def test_tensor_permute():
855
854
a = ttgl .full ([32 , 16 ], 0 , ttgl .int32 , layout = layout )
856
855
# CHECK: tt.trans{{.*}} : tensor<32x16xi32, [[BLOCKED]]> -> tensor<16x32xi32, [[BLOCKED1]]>
857
856
res = ttgl .permute (a , [1 , 0 ])
858
- permuted_layout : ttgl .constexpr = ttgl .BlockedLayout ([2 , 1 ], [8 , 4 ], [1 , 4 ], [0 , 1 ], [ 1 , 1 ], [ 1 , 1 ], [ 1 , 0 ] )
857
+ permuted_layout : ttgl .constexpr = ttgl .BlockedLayout ([2 , 1 ], [8 , 4 ], [1 , 4 ], [0 , 1 ])
859
858
ttgl .static_assert (permuted_layout == res .type .layout )
860
859
861
860
@@ -869,7 +868,7 @@ def test_split_join():
869
868
b = ttgl .full ([128 ], 2 , ttgl .int32 , layout )
870
869
# CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
871
870
res = ttgl .join (a , b )
872
- expect_layout : ttgl .constexpr = ttgl .BlockedLayout ([2 , 2 ], [32 , 1 ], [4 , 1 ], [1 , 0 ], [ 1 , 1 ], [ 1 , 1 ], [ 1 , 0 ] )
871
+ expect_layout : ttgl .constexpr = ttgl .BlockedLayout ([2 , 2 ], [32 , 1 ], [4 , 1 ], [1 , 0 ])
873
872
ttgl .static_assert (res .type .layout == expect_layout )
874
873
875
874
# CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, [[BLOCKED]]>
@@ -878,6 +877,17 @@ def test_split_join():
878
877
ttgl .static_assert (d .type .layout == layout )
879
878
880
879
880
+ @filecheck_test
881
+ @gluon .jit
882
+ def test_reshape_linear_layout ():
883
+ # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
884
+ # CHECK: [[LINEAR:#.*]] = #ttg.linear
885
+ layout : ttgl .constexpr = ttgl .BlockedLayout ([1 , 1 ], [32 , 1 ], [4 , 1 ], [0 , 1 ])
886
+ x = ttgl .full ([128 , 1 ], 1 , ttgl .int32 , layout = layout )
887
+ # CHECK: tt.reshape %{{.*}} : tensor<128x1xi32, [[BLOCKED]]> -> tensor<128xi32, [[LINEAR]]>
888
+ x .reshape ([128 ])
889
+
890
+
881
891
@filecheck_test
882
892
@gluon .jit
883
893
def test_tensor_reshape ():
@@ -887,8 +897,7 @@ def test_tensor_reshape():
887
897
a = ttgl .full ([256 ], 1 , ttgl .int32 , layout )
888
898
# CHECK: tt.reshape {{.*}} : tensor<256xi32, [[BLOCKED]]> -> tensor<8x4x8xi32, [[BLOCKED1]]>
889
899
v = a .reshape ([8 , 4 , 8 ])
890
- expect_layout : ttgl .constexpr = ttgl .BlockedLayout ([1 , 1 , 2 ], [2 , 4 , 4 ], [4 , 1 , 1 ], [2 , 1 , 0 ], [1 , 1 , 1 ], [1 , 1 , 1 ],
891
- [2 , 1 , 0 ])
900
+ expect_layout : ttgl .constexpr = ttgl .BlockedLayout ([1 , 1 , 2 ], [2 , 4 , 4 ], [4 , 1 , 1 ], [2 , 1 , 0 ])
892
901
ttgl .static_assert (v .type .layout == expect_layout )
893
902
894
903
0 commit comments