
Commit f1307bd

Improve Triton's coalescing pass to support propagation of layout for operations yielding multiple values (#4855)
This PR improves Triton's coalescing pass to support propagation of layout for operations that yield multiple values. The main enhancement allows the coalescing algorithm to correctly handle operations with multiple results by tracking individual values rather than assuming single-result operations.

Key changes:
- Refactored layout propagation methods to accept specific Value parameters instead of assuming single results
- Added support for operations with multiple return values in the coalescing pass
- Enhanced loop handling for scf::ForOp and scf::WhileOp to properly map individual operands to results

Fixes issues #4854, #4817.

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 3caf1e9 commit f1307bd
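
The crux of the change is that layout propagation now follows a specific result Value and, when that value feeds an scf.yield, maps the yielded operand back to the matching loop result by its operand number. A condensed sketch of that pattern (simplified from the Coalesce.cpp hunks below; not the verbatim pass code):

  // Only the loop result fed by the re-laid-out value is retyped and
  // propagated further; other results of the same loop are left alone.
  if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
    if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
      for (OpOperand &operand : yieldOp->getOpOperands())
        if (operand.get() == opRes) // the value whose layout just changed
          propagateLayoutToLoopResults(loopOp, operand.getOperandNumber(),
                                       layout, rewriter);
    }
  }

The coalesce.mlir reproducer below exercises exactly this case: an scf.for yields two block-pointer values with different layouts, and the CHECK lines assert that each downstream use keeps its own layout.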

File tree

- test/TritonIntelGPU/coalesce.mlir
- third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

2 files changed: +128 / -97 lines changed

test/TritonIntelGPU/coalesce.mlir

Lines changed: 29 additions & 0 deletions
@@ -558,3 +558,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return
   }
 }
+
+// -----
+
+// COM: Reproducer for issue #4854
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+  // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+  // CHECK: @test_4854
+  tt.func public @test_4854(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %c0_i32 = arith.constant 0 : i32
+    %c16_i32 = arith.constant 16 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %0 = tt.make_tensor_ptr %arg0, [%c128_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<128x32xf32, #blocked>>
+    %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c128_i64], [%c1_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x128xf32, #blocked1>>
+    %2:2 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c32_i32 iter_args(%arg3 = %0, %arg4 = %1) -> (!tt.ptr<tensor<128x32xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked1>>) : i32 {
+      %5 = tt.advance %arg4, [%c32_i32, %c0_i32] : <tensor<32x128xf32, #blocked1>>
+      scf.yield %arg3, %5 : !tt.ptr<tensor<128x32xf32, #blocked>>, !tt.ptr<tensor<32x128xf32, #blocked1>>
+    }
+    // CHECK: [[ADV:%.*]] = tt.advance {{.*}} : <tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>
+    %3 = tt.advance %2#0, [%c0_i32, %c16_i32] : <tensor<128x32xf32, #blocked>>
+    // CHECK: [[LOAD:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
+    %4 = tt.load %1 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<32x128xf32, #blocked1>>
+    tt.return
+  }
+}

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 99 additions & 97 deletions
@@ -7,11 +7,13 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Tools/StrUtil.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -154,55 +156,58 @@ struct CoalescePass
     return false;
   }
 
-  // Change the \p layout of the \p op result and propagate the new result type
-  // to its users.
-  void changeAndPropagateLayout(Operation *op, Attribute layout,
+  // Change the \p layout of the \p op's result \p opRes and propagate the new
+  // result type to its users.
+  void changeAndPropagateLayout(Operation *op, Value opRes, Attribute layout,
                                 IRRewriter &rewriter) const {
     assert(op && op->getNumResults() != 0 &&
            "Expecting operation yielding results");
 
     LLVM_DEBUG({
       llvm::dbgs() << "[" DEBUG_TYPE "]: " << "ChangeAndPropagateLayout for: ";
       op->dumpPretty();
+      llvm::dbgs() << "opRes: ";
+      opRes.printAsOperand(llvm::dbgs(), {});
+      llvm::dbgs() << "\n";
     });
 
     rewriter.modifyOpInPlace(op, [&]() {
-      for (Value res : op->getResults()) {
-        if (!tt::isTensorPointerType(res.getType()))
-          continue;
-
-        auto ptrType = cast<tt::PointerType>(res.getType());
-        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+      assert(tt::isTensorPointerType(opRes.getType()));
+      auto ptrType = cast<tt::PointerType>(opRes.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      opRes.setType(tt::PointerType::get(getNewType(tensorType, layout),
                                          ptrType.getAddressSpace()));
-      }
     });
 
     LLVM_DEBUG({
       llvm::dbgs() << "[" DEBUG_TYPE "]: Coalesced op: ";
       op->dumpPretty();
     });
 
-    propagateLayout(op, layout, rewriter);
+    for (OpResult res : op->getResults())
+      if (res == opRes)
+        propagateLayout(op, res, layout, rewriter);
   }
 
   // Propagate the layout of the \p root operation's result to its users.
-  void propagateLayout(Operation *root, Attribute layout,
+  void propagateLayout(Operation *op, Value opRes, Attribute layout,
                        IRRewriter &rewriter) const {
-    assert(root->getNumResults() != 0 &&
+    assert(op && op->getNumResults() != 0 &&
           "Expecting an operation yielding a result");
-
-    auto mod = root->getParentOfType<ModuleOp>();
+    assert(opRes &&
+           llvm::any_of(op->getResults(),
+                        [&](OpResult res) { return res == opRes; }) &&
+           "Expecting operation to yield 'opRes'");
 
     LLVM_DEBUG({
-      if (!root->getUsers().empty()) {
+      if (!opRes.getUsers().empty()) {
         llvm::dbgs() << "[" DEBUG_TYPE "]: "
-                     << "Propagate layout to operations using: ";
-        root->dumpPretty();
+                     << "Propagate layout to operations using: " << opRes
+                     << "\n";
       }
     });
 
-    for (Operation *user : root->getUsers()) {
+    for (Operation *user : opRes.getUsers()) {
       if (filterUser(user))
         continue;
 
@@ -212,50 +217,71 @@ struct CoalescePass
       });
 
       if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-        propagateLayoutToArgsAndBody(forOp, root, layout, rewriter);
+        propagateLayoutToArgsAndBody(forOp, opRes, layout, rewriter);
         continue;
       }
       if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
-        propagateLayoutToArgsAndBody(whileOp, root, layout, rewriter);
+        propagateLayoutToArgsAndBody(whileOp, opRes, layout, rewriter);
         continue;
       }
-
       if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
-          propagateLayoutToLoopResults(forOp, layout, rewriter);
-        if (auto whileOp = yieldOp->getParentOfType<scf::WhileOp>())
-          propagateLayoutToLoopResults(whileOp, layout, rewriter);
-        continue;
+        if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
+          for (OpOperand &operand : yieldOp->getOpOperands())
+            if (operand.get() == opRes)
+              propagateLayoutToLoopResults(loopOp, operand.getOperandNumber(),
+                                           layout, rewriter);
+          continue;
+        }
       }
 
       LLVM_DEBUG({
        llvm::dbgs() << "[" DEBUG_TYPE "]: After propagating layout:\n";
-        mod->dumpPretty();
+        op->getParentOfType<ModuleOp>()->dumpPretty();
       });
 
-      changeAndPropagateLayout(user, layout, rewriter);
+      for (OpResult res : user->getResults())
+        changeAndPropagateLayout(user, res, layout, rewriter);
     }
   }
 
   // Propagate the layout of the \p arg block argument to its users.
   void propagateLayout(BlockArgument arg, Attribute layout,
                        IRRewriter &rewriter) const {
+    LLVM_DEBUG({
+      if (!arg.getUsers().empty()) {
+        llvm::dbgs() << "[" DEBUG_TYPE "]: "
+                     << "Propagate layout to operations using: ";
+        arg.printAsOperand(llvm::dbgs(), {});
+        llvm::dbgs() << "\n";
+      }
+    });
+
     for (Operation *user : arg.getUsers()) {
       if (filterUser(user))
         continue;
 
       LLVM_DEBUG({
-        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "arg's user: ";
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "user: ";
         user->dumpPretty();
       });
 
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
-          propagateLayoutToLoopResults(forOp, layout, rewriter);
-        if (auto whileOp = yieldOp->getParentOfType<scf::WhileOp>())
-          propagateLayoutToLoopResults(whileOp, layout, rewriter);
+      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+        propagateLayoutToArgsAndBody(forOp, arg, layout, rewriter);
         continue;
       }
+      if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
+        propagateLayoutToArgsAndBody(whileOp, arg, layout, rewriter);
+        continue;
+      }
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        if (auto loopOp = yieldOp->getParentOfType<LoopLikeOpInterface>()) {
+          for (OpOperand &operand : yieldOp->getOpOperands())
+            if (operand.get() == arg)
+              propagateLayoutToLoopResults(loopOp, operand.getOperandNumber(),
+                                           layout, rewriter);
+          continue;
+        }
+      }
       if (auto condOp = dyn_cast<scf::ConditionOp>(user)) {
        if (auto whileOp = condOp->getParentOfType<scf::WhileOp>()) {
          // Propagate layout to "after" region arguments.
@@ -284,7 +310,8 @@ struct CoalescePass
         continue;
       }
 
-      changeAndPropagateLayout(user, layout, rewriter);
+      for (OpResult res : user->getResults())
+        changeAndPropagateLayout(user, res, layout, rewriter);
     }
 
     LLVM_DEBUG({
@@ -300,74 +327,48 @@ struct CoalescePass
   // loop body that use that argument.
   template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
                                  OpType, scf::ForOp, scf::WhileOp>::value>>
-  void propagateLayoutToArgsAndBody(OpType loopOp, Operation *root,
+  void propagateLayoutToArgsAndBody(OpType loopOp, Value opRes,
                                     Attribute layout,
                                     IRRewriter &rewriter) const {
-    assert(llvm::any_of(root->getUsers(),
-                        [&](Operation *user) { return user == loopOp; }) &&
-           "Expecting the loop to be a user of the root operation");
-
-    for (BlockArgument arg : loopOp.getRegionIterArgs()) {
-      Value loopArg;
-      if constexpr (std::is_same<OpType, scf::ForOp>::value)
-        loopArg = loopOp.getInitArgs()[arg.getArgNumber() - 1];
-      if constexpr (std::is_same<OpType, scf::WhileOp>::value)
-        loopArg = loopOp.getInits()[arg.getArgNumber()];
-
-      for (OpResult res : root->getResults()) {
-        if (res != loopArg || !tt::isTensorPointerType(res.getType()))
-          continue;
-        // Modify the layout of the loop init argument...
-        tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
-        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-        arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                         ptrType.getAddressSpace()));
-        LLVM_DEBUG({
-          llvm::dbgs() << "[" DEBUG_TYPE "]: " << "Propagated layout to: ";
-          arg.printAsOperand(llvm::dbgs(), {});
-          llvm::dbgs() << "\n";
-        });
-
-        // ... and then propagate it to the operations in the loop.
-        propagateLayout(arg, layout, rewriter);
-      }
+    for (auto [initArg, arg] :
+         llvm::zip(loopOp.getInitsMutable(), loopOp.getRegionIterArgs())) {
+      if (initArg.get() != opRes)
+        continue;
+
+      // Modify the layout of the loop init argument...
+      auto ptrType = cast<tt::PointerType>(arg.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                       ptrType.getAddressSpace()));
+
+      LLVM_DEBUG({
+        llvm::dbgs() << "[" DEBUG_TYPE "]: " << "Propagated layout to: ";
+        arg.printAsOperand(llvm::dbgs(), {});
+        llvm::dbgs() << "\n";
+      });
+
+      // ... and then propagate it to the operations in the loop.
+      propagateLayout(arg, layout, rewriter);
     }
   }
 
-  // Modify the given loop \p loopOpt and propagate its results to their users.
-  template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
-                                 OpType, scf::ForOp, scf::WhileOp>::value>>
-  void propagateLayoutToLoopResults(OpType loopOp, Attribute layout,
+  // Modify the \p layout to the loop's operand identified by \p resNum, and
+  // propagate the modified loop results to its users.
+  void propagateLayoutToLoopResults(LoopLikeOpInterface loopOp, unsigned resNum,
+                                    Attribute layout,
                                     IRRewriter &rewriter) const {
-    Operation *yieldOp = nullptr;
-    if constexpr (std::is_same<OpType, scf::ForOp>::value)
-      yieldOp = loopOp.getBody()->getTerminator();
-    if constexpr (std::is_same<OpType, scf::WhileOp>::value)
-      yieldOp = loopOp.getYieldOp();
-
+    Value loopRes = loopOp->getResult(resNum);
     rewriter.modifyOpInPlace(loopOp, [&]() {
-      for (auto [yieldOperandType, res] :
-           llvm::zip(yieldOp->getOperandTypes(), loopOp.getResults())) {
-        Type resType = res.getType();
-        if (yieldOperandType == resType)
-          continue;
-
-        assert(tt::isTensorPointerType(resType) &&
-               tt::isTensorPointerType(yieldOperandType) &&
-               "Expecting blocked pointers");
-        assert(cast<RankedTensorType>(
-                   cast<tt::PointerType>(yieldOperandType).getPointeeType())
-                       .getEncoding() == layout &&
-               "Unexpected layout");
-
-        auto ptrType = cast<tt::PointerType>(res.getType());
-        RankedTensorType tensorType = ttgi::getRankedTensorType(resType);
-        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                         ptrType.getAddressSpace()));
-      }
+      assert(tt::isTensorPointerType(loopRes.getType()) &&
+             "Expecting blocked pointers");
+      Type resType = loopRes.getType();
+      auto ptrType = cast<tt::PointerType>(resType);
+      RankedTensorType tensorType = ttgi::getRankedTensorType(resType);
+      loopRes.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                           ptrType.getAddressSpace()));
     });
 
-    propagateLayout(loopOp, layout, rewriter);
+    propagateLayout(loopOp, loopRes, layout, rewriter);
   }
 
   void coalesceOp(Attribute encoding, Operation *op) {
@@ -404,7 +405,8 @@ struct CoalescePass
       }
 
       IRRewriter rewriter(builder);
-      changeAndPropagateLayout(*defOp, encoding, rewriter);
+      changeAndPropagateLayout(*defOp, defOp->getResult(), encoding,
+                               rewriter);
      newArgs.push_back(operand);
    }
  }
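
Call sites follow the same per-value discipline; condensed from the hunks above, a user operation is now processed one result at a time:

  for (OpResult res : user->getResults())
    changeAndPropagateLayout(user, res, layout, rewriter);

so each tensor-pointer result is retyped and its users are updated independently of the operation's other results.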
