1
- #include " third_party/amd/include/ TritonAMDGPUToLLVM/MembarUtility.h"
1
+ #include " TritonAMDGPUToLLVM/MembarUtility.h"
2
2
#include " Dialect/TritonAMDGPU/IR/Dialect.h"
3
- #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4
3
#include " triton/Dialect/TritonGPU/IR/Dialect.h"
5
4
6
5
namespace mlir ::triton::AMD {
7
6
namespace {
7
+ constexpr const char *syncedViaAsyncWaitAttrName =
8
+ " ttg.amdgpu.syncedViaAsyncWait" ;
8
9
9
10
// Traverses the def-chain including control flow of the token and returns true
10
11
// if all defining operations are an AsyncWait
@@ -31,16 +32,12 @@ bool comesFromAsyncWait(Value token) {
31
32
// argId to see if they are immediately an AsyncWait.
32
33
for (auto *pred : block->getPredecessors ()) {
33
34
auto terminator = pred->getTerminator ();
34
- if (auto br = dyn_cast<cf::BranchOp>(terminator)) {
35
- if (!destOperandFromAsyncWait (br.getDestOperands ()))
36
- return false ;
37
- } else if (auto condBr = dyn_cast<cf::CondBranchOp>(terminator)) {
38
- if (condBr.getTrueDest () == block) {
39
- if (!destOperandFromAsyncWait (condBr.getTrueDestOperands ()))
40
- return false ;
41
- }
42
- if (condBr.getFalseDest () == block) {
43
- if (!destOperandFromAsyncWait (condBr.getFalseDestOperands ()))
35
+ if (auto br = dyn_cast<BranchOpInterface>(terminator)) {
36
+ for (auto successor : llvm::enumerate (br->getSuccessors ())) {
37
+ if (block != successor.value ())
38
+ continue ;
39
+ auto operands = br.getSuccessorOperands (successor.index ());
40
+ if (!destOperandFromAsyncWait (operands))
44
41
return false ;
45
42
}
46
43
} else {
@@ -51,19 +48,14 @@ bool comesFromAsyncWait(Value token) {
51
48
}
52
49
53
50
// Returns true if one of the operands is a LocalLoad synced via AsyncWait.
54
- bool filterAsyncLocalLoadsDeppendencies (Operation *op1, Operation *op2) {
51
+ bool filterAsyncLocalLoadsDependencies (Operation *op1, Operation *op2) {
55
52
auto isAsyncLoad = [](Operation *op) {
56
53
return llvm::isa<triton::gpu::AsyncCopyGlobalToLocalOp,
57
54
triton::amdgpu::BufferLoadToLocalOp>(op);
58
55
};
59
56
auto isLocalLoadWithAsyncWaitToken = [](Operation *op) {
60
57
auto localLoad = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op);
61
- if (!localLoad)
62
- return false ;
63
- Value token = localLoad.getToken ();
64
- if (!token || !comesFromAsyncWait (token))
65
- return false ;
66
- return true ;
58
+ return localLoad && isSyncedViaAsyncWait (localLoad);
67
59
};
68
60
69
61
// Early return if neither or both operands are an AsyncLoad
@@ -76,7 +68,33 @@ bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) {
76
68
};
77
69
} // namespace
78
70
71
+ void annotateLocalLoadsSyncedViaAsyncWait (ModuleOp mod) {
72
+ SmallVector<triton::gpu::LocalLoadOp> localLoads;
73
+ mod->walk ([&](triton::gpu::LocalLoadOp localLoadOp) {
74
+ localLoads.emplace_back (localLoadOp);
75
+ });
76
+
77
+ auto *ctx = mod->getContext ();
78
+ for (auto &loadOp : localLoads) {
79
+ auto token = loadOp.getToken ();
80
+ bool isSyncedViaAsyncWait = token && comesFromAsyncWait (token);
81
+ loadOp->setAttr (syncedViaAsyncWaitAttrName,
82
+ BoolAttr::get (ctx, isSyncedViaAsyncWait));
83
+ }
84
+ }
85
+
86
+ bool isSyncedViaAsyncWait (triton::gpu::LocalLoadOp localLoadOp) {
87
+ auto attr = localLoadOp->getAttr (syncedViaAsyncWaitAttrName);
88
+ if (!attr) {
89
+ localLoadOp.emitRemark (" has no async sync information attached to it which "
90
+ " might negatively affect performance. Run "
91
+ " annotateLocalLoadSyncedViaAsyncWait first" );
92
+ return false ;
93
+ }
94
+ return cast<BoolAttr>(attr).getValue ();
95
+ }
96
+
79
97
bool membarFilter (Operation *op1, Operation *op2) {
80
- return filterAsyncLocalLoadsDeppendencies (op1, op2);
98
+ return filterAsyncLocalLoadsDependencies (op1, op2);
81
99
}
82
100
} // namespace mlir::triton::AMD
0 commit comments