@@ -256,3 +256,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: mma_reorder_transpose
+  // CHECK: triton_gpu.local_alloc
+  // CHECK: triton_gpu.memdesc_trans
+  // CHECK: triton_nvidia_gpu.warp_group_dot
+  tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked>
+    %dota = triton_gpu.local_alloc %a : (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1>
+    %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}