@@ -247,8 +247,8 @@ def shared_memory_cast_kernel():
247
247
rank = 2 )
248
248
layout_T : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = True , element_bitwidth = 8 ,
249
249
rank = 2 )
250
- smem = ttgl .allocate_shared_memory (ttgl .int8 , [256 , 128 ], layout_a )
251
- smem .permute ((1 , 0 ), layout_T )
250
+ smem = ttgl .allocate_shared_memory (ttgl .int8 , [2 , 256 , 128 ], layout_a )
251
+ smem .subslice ( 0 ). permute ((1 , 0 ), layout_T )
252
252
253
253
layout_b : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , transposed = False , element_bitwidth = 16 ,
254
254
rank = 4 , cta_order = [3 , 2 , 1 , 0 ])
@@ -271,11 +271,14 @@ def test_shared_memory_cast(fresh_knobs):
271
271
#smem = #ttg.shared_memory
272
272
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
273
273
tt.func public @shared_memory_cast_kernel() attributes {noinline = false} {
274
- %0 = ttg.local_alloc : () -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable>
275
- %1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>
276
- %2 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
277
- %3 = ttg.memdesc_reshape %2 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
278
- %4 = ttg.memdesc_reinterpret %2 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
274
+ %0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable>
275
+ %c0_i32 = arith.constant 0 : i32
276
+ %c0_i32_0 = arith.constant 0 : i32
277
+ %1 = ttg.memdesc_subview %0[%c0_i32_0, %c0_i32, %c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128>
278
+ %2 = ttg.memdesc_trans %1 {order = array<i32: 1, 0>} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>
279
+ %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable>
280
+ %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable, 32x1x4x64>
281
+ %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable>
279
282
tt.return
280
283
}
281
284
}
0 commit comments