
Commit f81f19a

[release/3.4] "[BACKEND] support tt::TransOp in comesFromLoadOrBlockArg (triton-lang#7343)" (triton-lang#7346)
This patches a bug (pytorch/pytorch#156028) that was introduced by triton-lang#7066. That PR refactored comesFromLoadOrBlockArg (used by PromoteLHSToTMem.cpp) with the intent of extending it to support all MemDescViewTrait ops; in the refactor, however, support for tt::TransOp was dropped, which changed the behavior of AccelerateMatmul.cpp. This PR adds tt::TransOp back into the set of ops supported by comesFromLoadOrBlockArg. In other words:

* behavior before triton-lang#7066: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp and **tt::TransOp**
* behavior after triton-lang#7066: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp, ttg::MemDescSubview, ttg::MemDescTransOp, ttg::MemDescReshapeOp, and ttg::MemDescReinterpretOp
* behavior after this PR: comesFromLoadOrBlockArg tracks loads past ttg::ConvertLayoutOp, **tt::TransOp**, ttg::MemDescSubview, ttg::MemDescTransOp, ttg::MemDescReshapeOp, and ttg::MemDescReinterpretOp
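For readers skimming the diff, a minimal sketch of the traversal with the tt::TransOp case restored is included below. Only the three if-blocks mirror code that actually appears in this commit's hunk in lib/Dialect/TritonGPU/Transforms/Utility.cpp; the loop shape, the tt/ttg namespace aliases, and the final load/block-argument check (tt::LoadOp, tt::DescriptorLoadOp) are illustrative assumptions, not the exact upstream code.

```cpp
// Sketch only -- simplified view of comesFromLoadOrBlockArg.
// Assumed context (as in the TritonGPU transform sources):
//   using namespace mlir;
//   namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu;
bool comesFromLoadOrBlockArg(Value v) {
  // Walk backwards through ops that only re-view or re-lay-out the data.
  while (Operation *def = v.getDefiningOp()) {
    if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(def)) {
      v = cvtOp.getSrc();     // look through layout conversions
      continue;
    }
    if (auto transOp = dyn_cast<tt::TransOp>(def)) {
      v = transOp.getSrc();   // restored by this PR: look through tt.trans
      continue;
    }
    if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
      v = def->getOperand(0); // memdesc view ops, handled since triton-lang#7066
      continue;
    }
    break;                    // any other defining op ends the walk
  }
  // Assumed final check (outside this diff): a block argument has no defining
  // op; otherwise the value must come from a load-like op.
  Operation *def = v.getDefiningOp();
  return !def || isa<tt::LoadOp, tt::DescriptorLoadOp>(def);
}
```

With that case back, AccelerateMatmul.cpp again classifies a dot operand produced by tt.trans of a loaded tensor as load-fed, which is exactly what the new lit test below exercises.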
1 parent a9ccbac commit f81f19a

2 files changed: +37 -0 lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 4 additions & 0 deletions
@@ -1568,6 +1568,10 @@ bool comesFromLoadOrBlockArg(Value v) {
       v = cvtOp.getSrc();
       continue;
     }
+    if (auto transOp = dyn_cast<tt::TransOp>(def)) {
+      v = transOp.getSrc();
+      continue;
+    }
     if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
       v = def->getOperand(0);
       continue;

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 33 additions & 0 deletions
@@ -566,3 +566,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: identify_load_then_trans
+  tt.func public @identify_load_then_trans(
+      %arg0: !tt.tensordesc<tensor<128x128xf16>>,
+      %arg1: !tt.tensordesc<tensor<128x128xf16>>,
+      %arg2: i32,
+      %arg3: i32,
+      %arg4: i32,
+      %arg5: tensor<128x128xf32, #blocked>
+  ) -> tensor<128x128xf32, #blocked> {
+    // CHECK: %[[DESC0:.*]] = tt.descriptor_load %arg0
+    // CHECK: %[[DESC1:.*]] = tt.descriptor_load %arg1
+    %13 = tt.descriptor_load %arg0[%arg4, %arg2] : !tt.tensordesc<tensor<128x128xf16>> -> tensor<128x128xf16, #blocked2>
+    %14 = tt.descriptor_load %arg1[%arg3, %arg4] : !tt.tensordesc<tensor<128x128xf16>> -> tensor<128x128xf16, #blocked2>
+    // CHECK: %[[TRANS0:.*]] = tt.trans %[[DESC0]]
+    // CHECK: %[[ALLOC0:.*]] = ttg.local_alloc %[[TRANS0]]
+    %15 = tt.trans %13 {order = array<i32: 1, 0>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked3>
+    // CHECK: %[[TRANS1:.*]] = tt.trans %[[DESC1]]
+    // CHECK: %[[ALLOC1:.*]] = ttg.local_alloc %[[TRANS1]]
+    %16 = tt.trans %14 {order = array<i32: 1, 0>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked3>
+    %17 = ttg.convert_layout %15 : tensor<128x128xf16, #blocked3> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %18 = ttg.convert_layout %16 : tensor<128x128xf16, #blocked3> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: ttng.warp_group_dot %[[ALLOC0]], %[[ALLOC1]]
+    %19 = tt.dot %17, %18, %arg5, inputPrecision = tf32 : tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
+    tt.return %19 : tensor<128x128xf32, #blocked>
+  }
+}
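A short note on the new test, based only on the hunk above: the leading // ----- line opens a separate split-input case, so the function is processed on its own by the file's existing RUN line. The CHECK directives require that, with tt::TransOp handled again, the accelerate-matmul pass treats the tt.trans of each tt.descriptor_load result as coming from a load, materializes both dot operands with ttg.local_alloc, and lowers tt.dot to ttng.warp_group_dot on those allocations.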
