
Commit 7e20e48

etiotto and Copilot authored

[Coalescing]: Support layout propagation through scf.if nested in a loop (#4868)

This PR fixes layout propagation through scf.if operations nested within loops in the Triton Intel GPU coalescing pass.

Fixes #4867

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Copilot <[email protected]>

1 parent 4f8134b
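
Note on the fix, ahead of the diffs below: the pass previously updated only loop results (scf.for/scf.while) when it saw an scf.yield, so a yield feeding an scf.if result inside a loop was missed. A condensed sketch of the dispatch pattern the patch adopts follows; it compiles against MLIR's SCF and interface headers, but `handleYield` and `propagateToResult` are illustrative stand-ins, not the pass's real entry points.

// Sketch only: condenses the dispatch pattern adopted by this patch.
// `handleYield` and `propagateToResult` are illustrative stand-ins for the
// pass's internals, not its real entry points.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

// Stand-in for the pass's propagateLayoutToOperationResult: the real code
// rewrites the type of result `resNum` of `op` to carry the new layout.
static void propagateToResult(Operation *op, unsigned resNum) {
  (void)op;
  (void)resNum;
}

// An scf.yield forwards its operands to the results of whatever op owns the
// region: scf.for or scf.while (both LoopLikeOpInterface), or scf.if. A
// layout change on a yielded value must therefore be mirrored on the
// corresponding parent-op result, whichever of those parents it is.
static void handleYield(scf::YieldOp yieldOp, Value val) {
  Operation *parentOp = yieldOp->getParentOp();
  for (OpOperand &operand : yieldOp->getOpOperands()) {
    if (operand.get() != val)
      continue;
    llvm::TypeSwitch<Operation *>(parentOp)
        .Case<LoopLikeOpInterface, scf::IfOp>([&](auto op) {
          propagateToResult(op, operand.getOperandNumber());
        })
        .Default([](Operation *op) {
          llvm::report_fatal_error("unsupported scf.yield parent");
        });
  }
}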

File tree

2 files changed: +69 −26 lines

test/TritonIntelGPU/coalesce.mlir
Lines changed: 41 additions & 5 deletions

@@ -54,8 +54,7 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
-
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 // CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -343,7 +342,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32} {
 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
 // CHECK: @triton_red_fused_mul_sum_0
@@ -412,7 +411,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
-module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
 // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
@@ -474,7 +473,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
 #blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 4], warpsPerCTA = [1, 2, 1], order = [0, 1, 2]}>
-module attributes {ttig.min_sg_size = 16 : i32, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32} {
 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
 // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
@@ -587,3 +586,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return
   }
 }
+
+// -----
+
+// COM: Test layout propagation for nested operations (scf.if nested in scf.for).
+// COM: Reproducer for issue #4867
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+  // CHECK: @test_4867
+  tt.func public @test_4867(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i1) {
+    %c0_i32 = arith.constant 0 : i32
+    %c16_i32 = arith.constant 16 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<128x32xf32, #blocked>>
+    %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x128xf32, #blocked>>
+    %2:2 = scf.for %arg3 = %c0_i32 to %c32_i32 step %c32_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr<tensor<128x32xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked>>) : i32 {
+      // CHECK: scf.for {{.*}}
+      // CHECK-NOT: [[BLOCKED_LAYOUT]]>>
+      %adv = tt.advance %arg5, [%c32_i32, %c0_i32] : <tensor<32x128xf32, #blocked>>
+      %3:2 = scf.if %arg2 -> (!tt.ptr<tensor<32x128xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked>>) {
+        scf.yield %adv, %arg5 : !tt.ptr<tensor<32x128xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked>>
+      } else {
+        scf.yield %arg5, %adv : !tt.ptr<tensor<32x128xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked>>
+      }
+      // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
+      scf.yield %arg4, %3#0 : !tt.ptr<tensor<128x32xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked>>
+    }
+    // CHECK: [[ADV:%.*]] = tt.advance {{.*}} : <tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>
+    %3 = tt.advance %2#0, [%c0_i32, %c16_i32] : <tensor<128x32xf32, #blocked>>
+    // CHECK: [[LOAD:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
+    %4 = tt.load %1 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x128xf32, #blocked>>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp
Lines changed: 28 additions & 21 deletions

@@ -14,6 +14,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -212,23 +213,26 @@ struct CoalescePass
       user->dumpPretty();
     });
 
-    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-      propagateLayoutToArgsAndBody(forOp, val, layout, rewriter);
-      continue;
-    }
-    if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
-      propagateLayoutToArgsAndBody(whileOp, val, layout, rewriter);
+    if (auto loopOp = dyn_cast<LoopLikeOpInterface>(user)) {
+      propagateLayoutToArgsAndBody(loopOp, val, layout, rewriter);
       continue;
     }
     if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-      if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
-        for (OpOperand &operand : llvm::make_filter_range(
-                 yieldOp->getOpOperands(),
-                 [&val](OpOperand &operand) { return operand.get() == val; }))
-          propagateLayoutToLoopResult(loopOp, operand.getOperandNumber(),
-                                      layout, rewriter);
-        continue;
-      }
+      Operation *parentOp = yieldOp->getParentOp();
+      for (OpOperand &operand : llvm::make_filter_range(
+               yieldOp->getOpOperands(),
+               [&val](OpOperand &operand) { return operand.get() == val; }))
+        TypeSwitch<Operation *>(parentOp)
+            .Case<LoopLikeOpInterface, scf::IfOp>([&](auto op) {
+              propagateLayoutToOperationResult(op, operand.getOperandNumber(),
+                                               layout, rewriter);
+            })
+            .Default([](auto op) {
+              llvm::report_fatal_error(llvm::Twine(
+                  "Unsupported parent operation for scf.yield: '" +
+                  op->getName().getStringRef() + "'"));
+            });
+      continue;
     }
     if (auto condOp = dyn_cast<scf::ConditionOp>(user)) {
       if (auto whileOp = condOp->getParentOfType<scf::WhileOp>()) {
@@ -295,13 +299,16 @@ struct CoalescePass
     }
   }
 
-  // Modify the \p layout to the loop's operand identified by \p resNum, and
-  // propagate the modified loop results to its users.
-  void propagateLayoutToLoopResult(LoopLikeOpInterface loopOp, unsigned resNum,
-                                   Attribute layout,
-                                   IRRewriter &rewriter) const {
-    Value loopRes = loopOp->getResult(resNum);
-    rewriter.modifyOpInPlace(loopOp, [&]() {
+  // Modify the \p layout of the operation \p op result identified by \p resNum,
+  // and propagate the modified operation result to its users.
+  template <typename OpType,
+            typename = std::enable_if_t<
+                llvm::is_one_of<OpType, LoopLikeOpInterface, scf::IfOp>::value>>
+  void propagateLayoutToOperationResult(OpType op, unsigned resNum,
+                                        Attribute layout,
+                                        IRRewriter &rewriter) const {
+    Value loopRes = op->getResult(resNum);
+    rewriter.modifyOpInPlace(op, [&]() {
      assert(tt::isTensorPointerType(loopRes.getType()) &&
             "Expecting blocked pointers");
      Type resType = loopRes.getType();
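
A note on the template constraint in the last hunk: `std::enable_if_t` over `llvm::is_one_of` is plain SFINAE, restricting `propagateLayoutToOperationResult` to exactly the two types the TypeSwitch can hand it, so a call with any other op type fails at compile time rather than at runtime. A minimal, self-contained sketch of the same idiom; `ForLike`, `IfLike`, and `propagateResult` are hypothetical stand-ins, not MLIR or pass names.

// Minimal sketch of the SFINAE constraint used above; ForLike and IfLike
// are hypothetical stand-ins for LoopLikeOpInterface and scf::IfOp.
#include "llvm/ADT/STLExtras.h" // llvm::is_one_of
#include <type_traits>

struct ForLike {};
struct IfLike {};

// Instantiable only when OpType is one of the listed types; any other
// argument type is rejected by overload resolution at compile time.
template <typename OpType,
          typename = std::enable_if_t<
              llvm::is_one_of<OpType, ForLike, IfLike>::value>>
void propagateResult(OpType /*op*/) {
  // The real function rewrites the type of one op result here.
}

int main() {
  propagateResult(ForLike{}); // OK
  propagateResult(IfLike{});  // OK
  // propagateResult(42);     // error: no matching function
}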
