Skip to content

Commit d4432f6

Browse files
authored
[BACKEND] Add a new pass to insert fence.proxy.async for write after read hazard (#7262)
When loading data from smem using the generic proxy then writing using the async proxy we need to insert a fence before writing. This case was being missed in our current analysis. This commit adds a new pass to conservatively insert fences in this kind of scenario. We still need the existing insertFence pass as it can be smarter about fence insertion because it runs before lowering of the structured control flow. This pass is meant to be more conservative while the previous pass can optimize better. Didn't see significant perf regressions so far.
1 parent bf5913d commit d4432f6

File tree

8 files changed

+322
-23
lines changed

8 files changed

+322
-23
lines changed

include/triton/Analysis/Membar.h

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ struct BlockInfo {
9595
//===----------------------------------------------------------------------===//
9696
// Shared Memory Barrier Analysis
9797
//===----------------------------------------------------------------------===//
98-
class MembarAnalysis {
98+
99+
// Common class to analyze membar and fence placement.
100+
class MembarOrFenceAnalysis {
99101
using VirtualBlock = std::pair<Block *, Block::iterator>;
100102

101103
public:
@@ -113,15 +115,15 @@ class MembarAnalysis {
113115
/// a shared memory read. If the temporary storage is written but not read,
114116
/// it is considered as the problem of the operation itself but not the membar
115117
/// analysis.
116-
MembarAnalysis() = default;
117-
explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
118+
MembarOrFenceAnalysis() = default;
119+
explicit MembarOrFenceAnalysis(Allocation *allocation, MembarFilterFn filter)
118120
: allocation(allocation), filter(filter) {}
119121

120122
/// Runs the membar analysis to the given operation, inserts a barrier if
121123
/// necessary.
122124
void run(FuncBlockInfoMapT &funcBlockInfoMap);
123125

124-
private:
126+
protected:
125127
/// Applies the barrier analysis based on the SCF dialect, in which each
126128
/// region has a single basic block only.
127129
/// Example:
@@ -139,30 +141,44 @@ class MembarAnalysis {
139141
void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap,
140142
OpBuilder *builder);
141143

142-
/// Updates the BlockInfo operation based on the operation.
143-
void update(Operation *operation, BlockInfo *blockInfo,
144-
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
145-
146144
/// Collects the successors of the terminator
147145
void visitTerminator(Operation *operation,
148146
SmallVector<VirtualBlock> &successors);
149147

150-
void insertBarrier(Operation *operation, OpBuilder *builder);
148+
/// Updates the BlockInfo operation based on the operation.
149+
virtual void update(Operation *operation, BlockInfo *blockInfo,
150+
FuncBlockInfoMapT *funcBlockInfoMap,
151+
OpBuilder *builder) = 0;
151152

152-
private:
153153
Allocation *allocation = nullptr;
154154
MembarFilterFn filter = nullptr;
155155
};
156156

157+
class MembarAnalysis : public MembarOrFenceAnalysis {
158+
public:
159+
MembarAnalysis() = default;
160+
explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
161+
: MembarOrFenceAnalysis(allocation, filter) {}
162+
163+
private:
164+
/// Updates the BlockInfo operation based on the operation.
165+
virtual void update(Operation *operation, BlockInfo *blockInfo,
166+
FuncBlockInfoMapT *funcBlockInfoMap,
167+
OpBuilder *builder) override;
168+
169+
void insertBarrier(Operation *operation, OpBuilder *builder);
170+
};
171+
157172
/// Postorder traversal on the callgraph to insert membar instructions
158173
/// of each function.
159174
/// Each function maintains a BlockInfo map that includes all potential buffers
160175
/// after returning. This way users do not have to explicitly insert membars
161176
/// before and after function calls, but might be a bit conservative.
162-
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
177+
template <typename AnalysisType>
178+
class ModuleMembarOrFenceAnalysis : public CallGraph<BlockInfo> {
163179
public:
164-
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation,
165-
MembarFilterFn filter = nullptr)
180+
ModuleMembarOrFenceAnalysis(ModuleAllocation *moduleAllocation,
181+
MembarFilterFn filter = nullptr)
166182
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
167183
moduleAllocation(moduleAllocation), filter(filter) {}
168184

@@ -175,7 +191,7 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
175191
auto *allocation = moduleAllocation->getFuncData(funcOp);
176192
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
177193
if (inserted) {
178-
MembarAnalysis analysis(allocation, filter);
194+
AnalysisType analysis(allocation, filter);
179195
analysis.run(funcMap);
180196
}
181197
});
@@ -186,6 +202,8 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
186202
MembarFilterFn filter;
187203
};
188204

205+
typedef ModuleMembarOrFenceAnalysis<MembarAnalysis> ModuleMembarAnalysis;
206+
189207
} // namespace mlir
190208

191209
#endif // TRITON_ANALYSIS_MEMBAR_H

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp">
4141
}
4242

4343
def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
44+
let summary = "Insert fences across generic and async proxy.";
45+
46+
let description = [{
47+
This pass is to insert memory fences to ensure that memory operations are
48+
properly ordered across generic and async operations.
49+
This pass inserts fences at optimized locations.
50+
A later pass handles all the remaining functional requirements.
51+
}];
52+
53+
let dependentDialects = [
54+
"mlir::triton::gpu::TritonGPUDialect",
55+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
56+
];
57+
58+
let options = [
59+
Option<"computeCapability", "compute-capability",
60+
"int32_t", /*default*/"90",
61+
"device compute capability">
62+
];
63+
}
64+
65+
def TritonGPUProxyFenceInsertion : Pass<"triton-nvidia-gpu-proxy-fence-insertion", "mlir::ModuleOp"> {
4466
let summary = "Insert fences across generic and async proxy";
4567

4668
let description = [{

lib/Analysis/Membar.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
namespace mlir {
1010

11-
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
11+
void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
1212
FunctionOpInterface funcOp =
1313
dyn_cast<FunctionOpInterface>(allocation->getOperation());
1414
OpBuilder builder(funcOp.getContext());
1515
resolve(funcOp, &funcBlockInfoMap, &builder);
1616
}
1717

18-
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
19-
FuncBlockInfoMapT *funcBlockInfoMap,
20-
OpBuilder *builder) {
18+
void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
19+
FuncBlockInfoMapT *funcBlockInfoMap,
20+
OpBuilder *builder) {
2121
// Initialize the blockList. Operations are organized into "virtual blocks",
2222
// which represent segments of straight-line code analyzed by each iteration
2323
// of the dataflow analysis. Virtual blocks abstract over both control flow
@@ -103,8 +103,8 @@ void MembarAnalysis::resolve(FunctionOpInterface funcOp,
103103
});
104104
}
105105

106-
void MembarAnalysis::visitTerminator(Operation *op,
107-
SmallVector<VirtualBlock> &successors) {
106+
void MembarOrFenceAnalysis::visitTerminator(
107+
Operation *op, SmallVector<VirtualBlock> &successors) {
108108
if (isa<BranchOpInterface>(op)) {
109109
// Collect the block successors of the branch.
110110
for (Block *successor : op->getSuccessors())

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_triton_library(TritonNvidiaGPUTransforms
66
OptimizeTMemLayouts.cpp
77
PlanCTA.cpp
88
PromoteLHSToTMem.cpp
9+
ProxFenceInsertion.cpp
910
RemoveTMEMTokens.cpp
1011
TensorMemoryAllocation.cpp
1112
TMALowering.cpp
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
#include "triton/Analysis/Allocation.h"
2+
#include "triton/Analysis/Membar.h"
3+
#include "triton/Analysis/Utility.h"
4+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
5+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
6+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
7+
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// On Hopper+, async proxy is separate from generic proxy, so when shared memory
11+
// is accessed through the generic proxy and then through the async proxy, we
// need to insert a fence to ensure
12+
// memory consistency.
13+
// This pass analyzes dependencies and will conservatively insert fences to
14+
// avoid race conditions between proxies. Async proxy is defined here:
15+
// https://docs.nvidia.com/cuda/parallel-thread-execution/#async-proxy
16+
//
17+
// This pass runs after shared memory allocation, to make sure we insert fences
18+
// between ops accessing aliasing buffers if needed.
19+
//
20+
// We also run a fence insertion pass during optimization phase as it is easier
21+
// to insert fences at optimal locations based on structured control flow.
22+
//
23+
//===----------------------------------------------------------------------===//
24+
25+
namespace mlir {
26+
namespace triton {
27+
namespace nvidia_gpu {
28+
29+
#define GEN_PASS_DEF_TRITONGPUPROXYFENCEINSERTION
30+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
31+
32+
namespace {
33+
34+
bool isAsyncProxyWrite(Operation *op) {
35+
return isa<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp,
36+
triton::nvidia_gpu::AsyncTMAGatherOp>(op);
37+
}
38+
39+
Value getSmemDest(Operation *op) {
40+
if (auto asyncTMACopyGlobalToLocalOp =
41+
dyn_cast<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(op)) {
42+
return asyncTMACopyGlobalToLocalOp.getResult();
43+
}
44+
if (auto asyncTMAGatherOp =
45+
dyn_cast<triton::nvidia_gpu::AsyncTMAGatherOp>(op)) {
46+
return asyncTMAGatherOp.getResult();
47+
}
48+
return Value();
49+
}
50+
51+
bool isAsyncProxyRead(Operation *op) {
52+
return isa<triton::nvidia_gpu::WarpGroupDotOp,
53+
triton::nvidia_gpu::TCGen5MMAOp,
54+
triton::nvidia_gpu::TCGen5MMAScaledOp,
55+
triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp,
56+
triton::nvidia_gpu::AsyncTMAScatterOp,
57+
triton::nvidia_gpu::AsyncTMAReduceOp>(op);
58+
}
59+
60+
bool ignoreOpForProxyFence(Operation *op) {
61+
return isAsyncProxyRead(op) || isAsyncProxyWrite(op) ||
62+
isa<triton::nvidia_gpu::ArriveBarrierOp,
63+
triton::nvidia_gpu::TMEMCopyOp, triton::nvidia_gpu::WaitBarrierOp,
64+
triton::nvidia_gpu::InitBarrierOp,
65+
triton::nvidia_gpu::InvalBarrierOp>(op);
66+
}
67+
68+
bool filterFn(Operation *op, Operation *other) {
69+
return ignoreOpForProxyFence(other);
70+
}
71+
72+
//===----------------------------------------------------------------------===//
73+
// Proxy Fence Analysis
74+
//===----------------------------------------------------------------------===//
75+
class ProxyFenceAnalysis : public MembarOrFenceAnalysis {
76+
77+
public:
78+
ProxyFenceAnalysis() = default;
79+
explicit ProxyFenceAnalysis(Allocation *allocation, MembarFilterFn filter)
80+
: MembarOrFenceAnalysis(allocation, filter) {}
81+
82+
private:
83+
/// Updates the BlockInfo operation based on the operation.
84+
virtual void update(Operation *operation, BlockInfo *blockInfo,
85+
FuncBlockInfoMapT *funcBlockInfoMap,
86+
OpBuilder *builder) override;
87+
88+
void insertFence(Operation *operation, OpBuilder *builder);
89+
};
90+
91+
void ProxyFenceAnalysis::insertFence(Operation *op, OpBuilder *builder) {
92+
OpBuilder::InsertionGuard g(*builder);
93+
builder->create<triton::nvidia_gpu::FenceAsyncSharedOp>(op->getLoc(), false);
94+
}
95+
96+
void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo,
97+
FuncBlockInfoMapT *funcBlockInfoMap,
98+
OpBuilder *builder) {
99+
if (isa<triton::nvidia_gpu::FenceAsyncSharedOp>(op)) {
100+
// If the current op is a fence, we clear previous reads and writes
101+
blockInfo->sync();
102+
return;
103+
}
104+
BlockInfo curBlockInfo;
105+
BlockInfo proxyBlockInfo;
106+
107+
auto scratchBufferId = Allocation::InvalidBufferId;
108+
if (isa<triton::CallOp>(op)) {
109+
// Inter-function dependencies
110+
auto callOpInterface = dyn_cast<CallOpInterface>(op);
111+
if (auto callee =
112+
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable()))
113+
curBlockInfo = funcBlockInfoMap->lookup(callee);
114+
} else {
115+
// Intra-function dependencies
116+
if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
117+
// Explicit buffer
118+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
119+
effectInstances;
120+
memoryEffectOpInterface.getEffects(effectInstances);
121+
for (auto effectInstance : effectInstances) {
122+
if (auto value = effectInstance.getValue()) {
123+
for (auto bufferId : allocation->getBufferIds(value)) {
124+
if (bufferId != Allocation::InvalidBufferId) {
125+
// TODO: handle proxy read cases. Those are currently handled in
126+
// FenceInsertionPass where it can generate better placement for
127+
// the fence. But we should support a safe fallback here.
128+
if (isAsyncProxyWrite(op)) {
129+
if (value == getSmemDest(op)) {
130+
proxyBlockInfo
131+
.syncWriteIntervals[allocation->getAllocatedInterval(
132+
bufferId)]
133+
.insert(op);
134+
}
135+
} else if (isa<MemoryEffects::Write>(
136+
effectInstance.getEffect())) {
137+
curBlockInfo
138+
.syncWriteIntervals[allocation->getAllocatedInterval(
139+
bufferId)]
140+
.insert(op);
141+
} else if (isa<MemoryEffects::Read>(effectInstance.getEffect())) {
142+
curBlockInfo
143+
.syncReadIntervals[allocation->getAllocatedInterval(
144+
bufferId)]
145+
.insert(op);
146+
}
147+
}
148+
}
149+
}
150+
}
151+
}
152+
scratchBufferId = allocation->getBufferId(op);
153+
}
154+
155+
// Scratch buffer operations consist of a series of shared memory operations
156+
// starting from a shared memory write, followed by a series of shared memory
157+
// read/write operations, mark them as a read.
158+
if (scratchBufferId != Allocation::InvalidBufferId) {
159+
auto interval = allocation->getAllocatedInterval(scratchBufferId);
160+
curBlockInfo.syncReadIntervals[interval].insert(op);
161+
}
162+
if (isAsyncProxyWrite(op) || isAsyncProxyRead(op)) {
163+
if (proxyBlockInfo.isIntersected(*blockInfo, filter)) {
164+
builder->setInsertionPoint(op);
165+
insertFence(op, builder);
166+
blockInfo->sync();
167+
}
168+
}
169+
170+
// Update the region info, even if barrier is inserted, we have to maintain
171+
// the current op's read/write buffers.
172+
blockInfo->join(curBlockInfo);
173+
}
174+
} // namespace
175+
176+
struct ProxyFenceInsertionPass
177+
: public impl::TritonGPUProxyFenceInsertionBase<ProxyFenceInsertionPass> {
178+
179+
public:
180+
using impl::TritonGPUProxyFenceInsertionBase<
181+
ProxyFenceInsertionPass>::TritonGPUProxyFenceInsertionBase;
182+
void runOnOperation() override {
183+
// Only insert fences for compute capability 9.0 and above
184+
if (computeCapability < 90)
185+
return;
186+
ModuleOp mod = getOperation();
187+
ModuleAllocation allocation(mod);
188+
ModuleMembarOrFenceAnalysis<ProxyFenceAnalysis> analysis(&allocation,
189+
filterFn);
190+
analysis.run();
191+
}
192+
};
193+
194+
} // namespace nvidia_gpu
195+
} // namespace triton
196+
} // namespace mlir

0 commit comments

Comments
 (0)