Skip to content

Commit 3728fdf

Browse files
authored
[AMD] Combine redundant AsyncWaits in StreamPipeliner (#6435)
Moved `combineRedundantWaitOps` from `WGMMAPipeline` to `PipeliningUtility` to reuse it in the StreamPipeliner.
1 parent 61cb963 commit 3728fdf

File tree

4 files changed

+45
-36
lines changed

4 files changed

+45
-36
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ Value createAlloc(scf::ForOp forOp, RankedTensorType ty, Location loc,
7373
// Determine if the operation is a TMA load.
7474
bool isTMALoad(Operation *op);
7575

76+
// Look for consecutive wait ops and combine them into a single wait op.
77+
void combineRedundantWaitOps(
78+
llvm::SmallSetVector<gpu::AsyncWaitOp, 8> &waitOps);
79+
7680
// Get the type of the view of a multi-buffered tensor value.
7781
gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy);
7882
// Get a generic shared encoding for a tensor.

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,40 @@ bool mlir::triton::isTMALoad(Operation *op) {
327327
return isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op);
328328
}
329329

330+
void mlir::triton::combineRedundantWaitOps(
331+
llvm::SmallSetVector<ttg::AsyncWaitOp, 8> &waitOps) {
332+
llvm::MapVector<ttg::AsyncWaitOp, ttg::AsyncWaitOp> toDelete;
333+
for (auto waitOp : waitOps) {
334+
if (toDelete.count(waitOp))
335+
continue;
336+
SmallVector<ttg::AsyncWaitOp> waitGroup = {waitOp};
337+
SmallVector<Value> depTokens = waitOp.getOperands();
338+
unsigned minWaitNumber = waitOp.getNum();
339+
Operation *next = waitOp->getNextNode();
340+
while (next && !isa<ttg::AsyncCommitGroupOp>(next)) {
341+
if (auto nextWait = dyn_cast<ttg::AsyncWaitOp>(next)) {
342+
waitGroup.push_back(nextWait);
343+
minWaitNumber = std::min(minWaitNumber, nextWait.getNum());
344+
depTokens.append(nextWait.getOperands().begin(),
345+
nextWait.getOperands().end());
346+
}
347+
next = next->getNextNode();
348+
}
349+
if (waitGroup.size() == 1)
350+
continue;
351+
OpBuilder builder(waitGroup.front());
352+
auto newWaitOp = builder.create<ttg::AsyncWaitOp>(waitOp.getLoc(),
353+
depTokens, minWaitNumber);
354+
for (auto waitOp : waitGroup) {
355+
toDelete[waitOp] = newWaitOp;
356+
}
357+
}
358+
for (auto waitOp : toDelete) {
359+
waitOp.first->replaceAllUsesWith(waitOp.second);
360+
waitOp.first->erase();
361+
}
362+
}
363+
330364
ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy) {
331365
Attribute sharedMemorySpace =
332366
ttg::SharedMemorySpaceAttr::get(allocTy.getContext());

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -100,41 +100,6 @@ static int minNumInterleavedCommitOps(Operation *waitOp) {
100100
return minCommits;
101101
}
102102

103-
// Look for consecutive wait ops and combine them into a single wait op.
104-
static void
105-
combineRedundantWaitOps(llvm::SmallSetVector<ttg::AsyncWaitOp, 8> &waitOps) {
106-
llvm::MapVector<ttg::AsyncWaitOp, ttg::AsyncWaitOp> toDelete;
107-
for (auto waitOp : waitOps) {
108-
if (toDelete.count(waitOp))
109-
continue;
110-
SmallVector<ttg::AsyncWaitOp> waitGroup = {waitOp};
111-
SmallVector<Value> depTokens = waitOp.getOperands();
112-
unsigned minWaitNumber = waitOp.getNum();
113-
Operation *next = waitOp->getNextNode();
114-
while (next && !isa<ttg::AsyncCommitGroupOp>(next)) {
115-
if (auto nextWait = dyn_cast<ttg::AsyncWaitOp>(next)) {
116-
waitGroup.push_back(nextWait);
117-
minWaitNumber = std::min(minWaitNumber, nextWait.getNum());
118-
depTokens.append(nextWait.getOperands().begin(),
119-
nextWait.getOperands().end());
120-
}
121-
next = next->getNextNode();
122-
}
123-
if (waitGroup.size() == 1)
124-
continue;
125-
OpBuilder builder(waitGroup.front());
126-
auto newWaitOp = builder.create<ttg::AsyncWaitOp>(waitOp.getLoc(),
127-
depTokens, minWaitNumber);
128-
for (auto waitOp : waitGroup) {
129-
toDelete[waitOp] = newWaitOp;
130-
}
131-
}
132-
for (auto waitOp : toDelete) {
133-
waitOp.first->replaceAllUsesWith(waitOp.second);
134-
waitOp.first->erase();
135-
}
136-
}
137-
138103
/// Update wait op number by analyzing the number of async_commit_group ops
139104
/// along all paths.
140105
void mlir::triton::updateWaits(ModuleOp module) {
@@ -144,7 +109,7 @@ void mlir::triton::updateWaits(ModuleOp module) {
144109
waitOp.setNum(minNumCommits);
145110
waitOps.insert(waitOp);
146111
});
147-
combineRedundantWaitOps(waitOps);
112+
tt::combineRedundantWaitOps(waitOps);
148113
}
149114

150115
// Add the given values as operands of the given wait, and replace all uses of

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,12 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
10501050
globalPrefetch, localPrefetch, useAsyncCopy);
10511051
(void)sp.pipelineLoop();
10521052
}
1053+
1054+
if (useAsyncCopy) {
1055+
llvm::SmallSetVector<ttg::AsyncWaitOp, 8> waitOps;
1056+
moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); });
1057+
tt::combineRedundantWaitOps(waitOps);
1058+
}
10531059
}
10541060
};
10551061
} // namespace

0 commit comments

Comments
 (0)