
Commit 2dfc748

Revert "Revert "[BACKEND] Fix a missed transpose optimization during refactor (#5236)""
This reverts commit d1aaea2.
Parent: 8e8394a

File tree: 2 files changed, 21 additions and 2 deletions


lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp (2 additions, 2 deletions)

@@ -148,11 +148,11 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
   LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                 PatternRewriter &rewriter) const override {
     // Match outerCvt(trans(innerCvt(x))).
-    auto trans = cvtOp.getSrc().getDefiningOp<MemDescTransOp>();
+    auto trans = cvtOp.getSrc().getDefiningOp<TransOp>();
     if (!trans || trans.getOrder() != ArrayRef<int32_t>{1, 0})
       return failure();
 
-    auto srcTy = dyn_cast<RankedTensorType>(trans.getSrc().getType());
+    RankedTensorType srcTy = trans.getSrc().getType();
 
     if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
       srcTy = srcCvt.getSrc().getType();

test/TritonGPU/dot-operands.mlir (19 additions, 0 deletions)

@@ -282,3 +282,22 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
     tt.return %r : tensor<128x64xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+// CHECK-LABEL: mmav2_reorder_transpose
+// CHECK: triton_gpu.local_alloc
+// CHECK: triton_gpu.memdesc_trans
+// CHECK: triton_gpu.local_load
+// CHECK: tt.dot
+  tt.func @mmav2_reorder_transpose(%t: tensor<32x128xf16, #blocked1>, %dotb: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{
+    %a = tt.trans %t {order = array<i32: 1, 0>} : tensor<32x128xf16, #blocked1> -> tensor<128x32xf16, #blocked>
+    %cv = triton_gpu.convert_layout %a : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %r = tt.dot %cv, %dotb, %dotc, inputPrecision = tf32 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma>
+    tt.return %r : tensor<128x64xf32, #mma>
+  }
+}
