Skip to content

Commit 8766024

Browse files
authored
[Utility] fix pass-by-reference addIterArgsToLoop API (#7029)
This API with the `scf::ForOp&` passed by reference is a foot-gun so this PR refactors to return the new `ForOp` instead.
1 parent ff86d26 commit 8766024

File tree

8 files changed

+16
-16
lines changed

8 files changed

+16
-16
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ scf::ForOp replaceForOpWithNewSignature(
141141
SmallVectorImpl<std::tuple<Value, Value>> &replacements);
142142
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
143143
ValueRange newIterOperands);
144-
Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
145-
ValueRange newIterOperands);
144+
[[nodiscard]] scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop,
145+
ValueRange newIterOperands);
146146

147147
// Replace WhileOp with a new WhileOp with extra operands. The YieldOp is not
148148
// updated and needs to be updated separately for the loop to be correct.

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ ttng::TMEMAllocOp hoistTMEMAlloc(TMEMTokenAllocOp alloc, scf::ForOp &forOp) {
337337
// By hoisting the allocation out of the loop, we need to turn the underlying
338338
// memory variable into a loop-carried depdendency.
339339
auto tokType = builder.getType<AsyncTokenType>();
340-
Value newTok = addIterArgsToLoop(builder, forOp, newAlloc.getToken()).front();
340+
forOp = addIterArgsToLoop(builder, forOp, newAlloc.getToken());
341+
Value newTok = forOp.getRegionIterArgs().back();
341342
appendToForOpYield(forOp, joinLastMemoryUses(builder, alloc.getToken()));
342343

343344
if (src != nullptr) {

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ class OptimizeAccumulatorInitPass
249249
}
250250

251251
Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
252-
(void)addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
252+
forOp = addIterArgsToLoop(rewriter, forOp, {loopArgFlagValue});
253253
loopArgFlagValue =
254254
forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);
255255

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
556556
}
557557

558558
// Patch the loop to add the new loop carried dependencies.
559-
(void)addIterArgsToLoop(builder, forOp, newOperands);
559+
forOp = addIterArgsToLoop(builder, forOp, newOperands);
560560

561561
// Update yield op with temporary yield values
562562
auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
@@ -750,7 +750,7 @@ scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) {
750750
newOperands.push_back(zero);
751751
}
752752

753-
(void)addIterArgsToLoop(builder, forOp, newOperands);
753+
forOp = addIterArgsToLoop(builder, forOp, newOperands);
754754

755755
auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
756756
.slice(tmaCounterArgsStartIdx);

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ Value triton::sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
106106
// `in` is live into the loop body. `out` becomes the live-out if the
107107
// loop executes at least once.
108108
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
109-
(void)addIterArgsToLoop(rewriter, forOp, in);
109+
forOp = addIterArgsToLoop(rewriter, forOp, in);
110110
appendToForOpYield(forOp, out);
111111
out = forOp.getResults().back();
112112
continue;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,17 +682,15 @@ scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
682682
return newForOp;
683683
}
684684

685-
Block::BlockArgListType addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp &loop,
686-
ValueRange newIterOperands) {
687-
unsigned curArgIdx = loop.getNumRegionIterArgs();
685+
scf::ForOp addIterArgsToLoop(OpBuilder &rewriter, scf::ForOp loop,
686+
ValueRange newIterOperands) {
688687
scf::ForOp newLoop =
689688
replaceForOpWithNewSignature(rewriter, loop, newIterOperands);
690689
// Save the caller from insertion point invalidation.
691690
if (rewriter.getInsertionPoint() == loop->getIterator())
692691
rewriter.setInsertionPoint(newLoop);
693692
loop.erase();
694-
loop = newLoop;
695-
return loop.getRegionIterArgs().slice(curArgIdx);
693+
return newLoop;
696694
}
697695

698696
scf::WhileOp replaceWhileOpWithNewSignature(

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ addIndexAndPhase(PartitionBuilder &b, scf::ForOp &loop, unsigned numStages,
118118
b.setInsertionPoint(loop);
119119

120120
// Index and phase both start at 0.
121-
unsigned curArgIdx = loop.getNumRegionIterArgs();
122-
auto newArgs = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
121+
loop = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
122+
auto newArgs = loop.getRegionIterArgs().take_back(2);
123123
BlockArgument index = newArgs[0];
124124
BlockArgument phase = newArgs[1];
125125

@@ -488,7 +488,8 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
488488
createTMemAlloc(b, oldAllocOp, /*multiBuffered=*/true, numMmaStages);
489489

490490
// Use placeholder values for the indices in the loop.
491-
auto indexPhase = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
491+
loop = addIterArgsToLoop(b, loop, {b.intCst(0), b.intCst(0)});
492+
auto indexPhase = loop.getRegionIterArgs().take_back(2);
492493
BlockArgument index = indexPhase[0];
493494
BlockArgument phase = indexPhase[1];
494495

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ void StreamPipeliner::createStreamOps() {
893893

894894
unsigned newOperandIndex = forOp.getBody()->getNumArguments();
895895
// Patch the loop to add the new loop carried dependencies.
896-
(void)addIterArgsToLoop(builder, forOp, {extractIdx});
896+
forOp = addIterArgsToLoop(builder, forOp, {extractIdx});
897897

898898
// Create one counter for the extract indices to avoid creating long
899899
// live range.

0 commit comments

Comments
 (0)