
Commit de7301e

[BACKEND] support tt::TransOp in comesFromLoadOrBlockArg (#7343)
This patches a bug (pytorch/pytorch#156028) introduced by #7066. #7066 refactored comesFromLoadOrBlockArg so it could be reused in PromoteLHSToTMem.cpp, with the intent of extending it to support all MemDescViewTrait ops. However, the refactor dropped support for tt::TransOp, changing the behavior of AccelerateMatmul.cpp. This PR adds tt::TransOp back into the set of ops handled by comesFromLoadOrBlockArg; a sketch of the resulting traversal follows this list.

* behavior before #7066: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp and **tt::TransOp**
* behavior after #7066: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp, ttg::MemDescSubview, ttg::MemDescTransOp, ttg::MemDescReshapeOp, and ttg::MemDescReinterpretOp
* behavior after this PR: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp, **tt::TransOp**, ttg::MemDescSubview, ttg::MemDescTransOp, ttg::MemDescReshapeOp, and ttg::MemDescReinterpretOp
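For reference, a minimal sketch of the traversal with the tt::TransOp case restored. It assumes the loop structure visible in the Utility.cpp hunk below and the usual tt/ttg namespace aliases; the final load/block-argument check is paraphrased for illustration, not copied from the source.

bool comesFromLoadOrBlockArg(Value v) {
  // Walk up the def chain, looking through ops that merely re-view the data.
  while (Operation *def = v.getDefiningOp()) {
    if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(def)) {
      v = cvtOp.getSrc();      // look through layout conversions
      continue;
    }
    if (auto transOp = dyn_cast<tt::TransOp>(def)) {
      v = transOp.getSrc();    // case restored by this PR
      continue;
    }
    if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
      v = def->getOperand(0);  // memdesc subview/trans/reshape/reinterpret
      continue;
    }
    break;
  }
  // Paraphrased final check: the walk succeeds if it bottoms out at a block
  // argument or a load-like op (e.g. tt.load / tt.descriptor_load).
  Operation *def = v.getDefiningOp();
  return !def || isa<tt::LoadOp, tt::DescriptorLoadOp>(def);
}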
Parent commit: 21d2ef2

2 files changed: +37 -0 lines

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 4 additions & 0 deletions
@@ -1573,6 +1573,10 @@ bool comesFromLoadOrBlockArg(Value v) {
       v = cvtOp.getSrc();
       continue;
     }
+    if (auto transOp = dyn_cast<tt::TransOp>(def)) {
+      v = transOp.getSrc();
+      continue;
+    }
     if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
       v = def->getOperand(0);
       continue;

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 33 additions & 0 deletions
@@ -566,3 +566,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: identify_load_then_trans
+  tt.func public @identify_load_then_trans(
+    %arg0: !tt.tensordesc<tensor<128x128xf16>>,
+    %arg1: !tt.tensordesc<tensor<128x128xf16>>,
+    %arg2: i32,
+    %arg3: i32,
+    %arg4: i32,
+    %arg5: tensor<128x128xf32, #blocked>
+  ) -> tensor<128x128xf32, #blocked> {
+    // CHECK: %[[DESC0:.*]] = tt.descriptor_load %arg0
+    // CHECK: %[[DESC1:.*]] = tt.descriptor_load %arg1
+    %13 = tt.descriptor_load %arg0[%arg4, %arg2] : !tt.tensordesc<tensor<128x128xf16>> -> tensor<128x128xf16, #blocked2>
+    %14 = tt.descriptor_load %arg1[%arg3, %arg4] : !tt.tensordesc<tensor<128x128xf16>> -> tensor<128x128xf16, #blocked2>
+    // CHECK: %[[TRANS0:.*]] = tt.trans %[[DESC0]]
+    // CHECK: %[[ALLOC0:.*]] = ttg.local_alloc %[[TRANS0]]
+    %15 = tt.trans %13 {order = array<i32: 1, 0>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked3>
+    // CHECK: %[[TRANS1:.*]] = tt.trans %[[DESC1]]
+    // CHECK: %[[ALLOC1:.*]] = ttg.local_alloc %[[TRANS1]]
+    %16 = tt.trans %14 {order = array<i32: 1, 0>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked3>
+    %17 = ttg.convert_layout %15 : tensor<128x128xf16, #blocked3> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %18 = ttg.convert_layout %16 : tensor<128x128xf16, #blocked3> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: ttng.warp_group_dot %[[ALLOC0]], %[[ALLOC1]]
+    %19 = tt.dot %17, %18, %arg5, inputPrecision = tf32 : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
+    tt.return %19 : tensor<128x128xf32, #blocked>
+  }
+}
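For illustration only, a hypothetical caller-side sketch: the helper name comesFromLoadOrBlockArg is real, but every other name below is invented for this example, and the ttg::LocalAllocOp builder signature is assumed. It shows how a lowering pass might gate shared-memory promotion of a dot operand on this helper, which is what the test above exercises now that tt.trans no longer breaks the walk back to tt.descriptor_load.

// Hypothetical sketch, not the actual AccelerateMatmul.cpp code.
static Value promoteDotOperandToSmem(OpBuilder &b, Location loc,
                                     Value operand, Type memDescTy) {
  // Only promote operands whose producer chain bottoms out at a load or a
  // block argument; with this PR that chain may include tt.trans.
  if (!comesFromLoadOrBlockArg(operand))
    return Value();
  // Builder call assumed for illustration: wrap the operand in a shared-memory
  // allocation, matching the CHECK lines above (ttg.local_alloc feeding
  // ttng.warp_group_dot).
  return b.create<ttg::LocalAllocOp>(loc, memDescTy, operand);
}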
