@@ -282,3 +282,23 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_gpu.local_load
+// CHECK: tt.dot
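+// Expect the tt.trans + convert_layout chain on operand A to be rewritten into a
+// shared-memory round trip: a local_alloc of the tile, a memdesc_trans view, and
+// a local_load producing the dot-operand layout consumed directly by tt.dot.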
+  tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma> {
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}