Skip to content

Commit 859dcf0

Browse files
authored
[AMD] Refactor Membar filter for LocalLoads synced via AsyncWait (#7047)
Instead of walking the def-chain of the `AsyncToken` inside the membar filter, we now walk it once before running the membar analysis. This change also makes the branch handling more generic by using `BranchOpInterface` instead of the specific branch instructions. In addition, it allows us to reuse the information when adding alias information while lowering `LocalLoads`, which will be enabled in a follow-up PR.
1 parent 4f51f8d commit 859dcf0

File tree

4 files changed

+51
-22
lines changed

4 files changed

+51
-22
lines changed

third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_MEMBARUTILITY_H_
22
#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_MEMBARUTILITY_H_
33

4+
#include "mlir/IR/BuiltinOps.h"
45
#include "mlir/IR/Operation.h"
6+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
57

68
namespace mlir::triton::AMD {
9+
10+
// Annotates LocalLoadOps with ttg.amdgpu.syncedViaAsyncWait=true if they are
11+
// synced by an AsyncWait.
12+
void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod);
13+
14+
// Getter for the annotation applied by annotateLocalLoadsSyncedViaAsyncWait
15+
bool isSyncedViaAsyncWait(triton::gpu::LocalLoadOp localLoadOp);
16+
717
// Filter function used in the AMDGPU backend to filter unnecessary barriers
818
// during Membar Analysis. Filters applied by this function:
919
// 1) Do not create barriers between AsyncCopyGlobalToLocal and LocalLoad if the

third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
#include "third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h"
1+
#include "TritonAMDGPUToLLVM/MembarUtility.h"
22
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
3-
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
43
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
54

65
namespace mlir::triton::AMD {
76
namespace {
7+
constexpr const char *syncedViaAsyncWaitAttrName =
8+
"ttg.amdgpu.syncedViaAsyncWait";
89

910
// Traverses the def-chain including control flow of the token and returns true
1011
// if all defining operations are an AsyncWait
@@ -31,16 +32,12 @@ bool comesFromAsyncWait(Value token) {
3132
// argId to see if they are immediately an AsyncWait.
3233
for (auto *pred : block->getPredecessors()) {
3334
auto terminator = pred->getTerminator();
34-
if (auto br = dyn_cast<cf::BranchOp>(terminator)) {
35-
if (!destOperandFromAsyncWait(br.getDestOperands()))
36-
return false;
37-
} else if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
38-
if (condBr.getTrueDest() == block) {
39-
if (!destOperandFromAsyncWait(condBr.getTrueDestOperands()))
40-
return false;
41-
}
42-
if (condBr.getFalseDest() == block) {
43-
if (!destOperandFromAsyncWait(condBr.getFalseDestOperands()))
35+
if (auto br = dyn_cast<BranchOpInterface>(terminator)) {
36+
for (auto successor : llvm::enumerate(br->getSuccessors())) {
37+
if (block != successor.value())
38+
continue;
39+
auto operands = br.getSuccessorOperands(successor.index());
40+
if (!destOperandFromAsyncWait(operands))
4441
return false;
4542
}
4643
} else {
@@ -51,19 +48,14 @@ bool comesFromAsyncWait(Value token) {
5148
}
5249

5350
// Returns true if one of the operands is a LocalLoad synced via AsyncWait.
54-
bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) {
51+
bool filterAsyncLocalLoadsDependencies(Operation *op1, Operation *op2) {
5552
auto isAsyncLoad = [](Operation *op) {
5653
return llvm::isa<triton::gpu::AsyncCopyGlobalToLocalOp,
5754
triton::amdgpu::BufferLoadToLocalOp>(op);
5855
};
5956
auto isLocalLoadWithAsyncWaitToken = [](Operation *op) {
6057
auto localLoad = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op);
61-
if (!localLoad)
62-
return false;
63-
Value token = localLoad.getToken();
64-
if (!token || !comesFromAsyncWait(token))
65-
return false;
66-
return true;
58+
return localLoad && isSyncedViaAsyncWait(localLoad);
6759
};
6860

6961
// Early return if neither or both operands are an AsyncLoad
@@ -76,7 +68,33 @@ bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) {
7668
};
7769
} // namespace
7870

71+
// Attaches the boolean attribute named by syncedViaAsyncWaitAttrName to every
// LocalLoadOp found in `mod`. The attribute is true when the load has a token
// and comesFromAsyncWait() accepts that token; otherwise it is false.
void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) {
  auto *context = mod->getContext();

  // Gather all local loads up front, then annotate them in a second step.
  SmallVector<triton::gpu::LocalLoadOp> pendingLoads;
  mod->walk([&pendingLoads](triton::gpu::LocalLoadOp candidate) {
    pendingLoads.emplace_back(candidate);
  });

  for (auto &pending : pendingLoads) {
    auto asyncToken = pending.getToken();
    // A load without a token is never considered synced.
    bool synced = false;
    if (asyncToken)
      synced = comesFromAsyncWait(asyncToken);
    pending->setAttr(syncedViaAsyncWaitAttrName,
                     BoolAttr::get(context, synced));
  }
}
85+
86+
// Reads the boolean attribute named by syncedViaAsyncWaitAttrName from
// `localLoadOp` and returns its value.
//
// If the attribute is absent, a remark is emitted on the op and false is
// returned. The remark names annotateLocalLoadsSyncedViaAsyncWait, the
// function that attaches the attribute.
bool isSyncedViaAsyncWait(triton::gpu::LocalLoadOp localLoadOp) {
  auto attr = localLoadOp->getAttr(syncedViaAsyncWaitAttrName);
  if (!attr) {
    localLoadOp.emitRemark("has no async sync information attached to it which "
                           "might negatively affect performance. Run "
                           "annotateLocalLoadsSyncedViaAsyncWait first");
    return false;
  }
  return cast<BoolAttr>(attr).getValue();
}
96+
7997
bool membarFilter(Operation *op1, Operation *op2) {
80-
return filterAsyncLocalLoadsDeppendencies(op1, op2);
98+
return filterAsyncLocalLoadsDependencies(op1, op2);
8199
}
82100
} // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ struct ConvertTritonAMDGPUToLLVM
100100
// Allocate shared memory and set barrier
101101
ModuleAllocation allocation(mod);
102102

103+
AMD::annotateLocalLoadsSyncedViaAsyncWait(mod);
103104
ModuleMembarAnalysis membarPass(&allocation,
104105
mlir::triton::AMD::membarFilter);
105106
membarPass.run();

third_party/amd/test/lib/Analysis/TestAMDGPUMembar.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ struct TestAMDGPUMembarPass
2121

2222
void runOnOperation() override {
2323
ModuleOp moduleOp = getOperation();
24+
triton::AMD::annotateLocalLoadsSyncedViaAsyncWait(moduleOp);
2425
// Print all ops after membar pass
2526
ModuleAllocation allocation(moduleOp);
26-
ModuleMembarAnalysis membarPass(&allocation,
27-
mlir::triton::AMD::membarFilter);
27+
ModuleMembarAnalysis membarPass(&allocation, triton::AMD::membarFilter);
2828
membarPass.run();
2929
}
3030
};

0 commit comments

Comments
 (0)