
Commit 4ae95e7

[BACKEND] Fix transpose optimization missed during refactor (#5226)
1 parent 82b8f0f commit 4ae95e7

File tree: 2 files changed (+22, -2 lines)
  lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
  test/TritonGPU/dot-operands.mlir

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 2 additions & 2 deletions
@@ -326,13 +326,13 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
       return failure();
 
     // Match outerCvt(trans(innerCvt(x))).
-    auto trans = allocOp.getSrc().getDefiningOp<MemDescTransOp>();
+    auto trans = allocOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>({1, 0}))
       return failure();
 
     MemDescType allocType = allocOp.getType();
     auto allocEncoding = cast<SharedEncodingAttr>(allocType.getEncoding());
-    MemDescType srcTy = trans.getSrc().getType();
+    RankedTensorType srcTy = trans.getSrc().getType();
 
     // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3
     // without transpose for other data types.)

test/TritonGPU/dot-operands.mlir

Lines changed: 20 additions & 0 deletions
@@ -256,3 +256,23 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mma_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_nvidia_gpu.warp_group_dot
+  tt.func @mma_reorder_transpose(%t: tensor<64x128xf16, #blocked1>, %dotb: !triton_gpu.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked>
+    %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1>
+    %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !triton_gpu.memdesc<128x64xf16, #shared1> * !triton_gpu.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
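
The new test documents the intended rewrite. As a hedged sketch (the "after" IR below is inferred from the CHECK lines rather than copied from the pass output, and its result types and encodings are assumptions), FuseTransHopper is expected to absorb the tt.trans on the register tensor, allocate the untransposed operand to shared memory, and re-express the transpose as triton_gpu.memdesc_trans on the shared-memory descriptor, which triton_nvidia_gpu.warp_group_dot then consumes directly:

// Before the pass (as written in the test): transpose in registers, then local_alloc.
%a    = tt.trans %t {order = array<i32: 1, 0>} : tensor<64x128xf16, #blocked1> -> tensor<128x64xf16, #blocked>
%dota = triton_gpu.local_alloc %a : (tensor<128x64xf16, #blocked>) -> !triton_gpu.memdesc<128x64xf16, #shared1>

// After the pass (illustrative only): allocate first, then transpose the memory descriptor.
%alloc = triton_gpu.local_alloc %t : (tensor<64x128xf16, #blocked1>) -> !triton_gpu.memdesc<64x128xf16, #shared1>
%dota  = triton_gpu.memdesc_trans %alloc {order = array<i32: 1, 0>} : !triton_gpu.memdesc<64x128xf16, #shared1> -> !triton_gpu.memdesc<128x64xf16, #shared>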
