Skip to content

Commit 2b612ce

Browse files
Merge OpenAI Triton commit 09d5113 (#4586)
This PR changes the Triton base from 6c3d943 to 09d5113 (Jun 23). Pass rate: 97.12%
2 parents 63aeb2b + 35023b0 commit 2b612ce

File tree

43 files changed

+1292
-255
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1292
-255
lines changed

include/triton/Analysis/Membar.h

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ struct BlockInfo {
9595
//===----------------------------------------------------------------------===//
9696
// Shared Memory Barrier Analysis
9797
//===----------------------------------------------------------------------===//
98-
class MembarAnalysis {
98+
99+
// Common class to analyze membar and fence placement.
100+
class MembarOrFenceAnalysis {
99101
using VirtualBlock = std::pair<Block *, Block::iterator>;
100102

101103
public:
@@ -113,15 +115,15 @@ class MembarAnalysis {
113115
/// a shared memory read. If the temporary storage is written but not read,
114116
/// it is considered as the problem of the operation itself but not the membar
115117
/// analysis.
116-
MembarAnalysis() = default;
117-
explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
118+
MembarOrFenceAnalysis() = default;
119+
explicit MembarOrFenceAnalysis(Allocation *allocation, MembarFilterFn filter)
118120
: allocation(allocation), filter(filter) {}
119121

120122
/// Runs the membar analysis to the given operation, inserts a barrier if
121123
/// necessary.
122124
void run(FuncBlockInfoMapT &funcBlockInfoMap);
123125

124-
private:
126+
protected:
125127
/// Applies the barrier analysis based on the SCF dialect, in which each
126128
/// region has a single basic block only.
127129
/// Example:
@@ -139,30 +141,44 @@ class MembarAnalysis {
139141
void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap,
140142
OpBuilder *builder);
141143

142-
/// Updates the BlockInfo operation based on the operation.
143-
void update(Operation *operation, BlockInfo *blockInfo,
144-
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
145-
146144
/// Collects the successors of the terminator
147145
void visitTerminator(Operation *operation,
148146
SmallVector<VirtualBlock> &successors);
149147

150-
void insertBarrier(Operation *operation, OpBuilder *builder);
148+
/// Updates the BlockInfo operation based on the operation.
149+
virtual void update(Operation *operation, BlockInfo *blockInfo,
150+
FuncBlockInfoMapT *funcBlockInfoMap,
151+
OpBuilder *builder) = 0;
151152

152-
private:
153153
Allocation *allocation = nullptr;
154154
MembarFilterFn filter = nullptr;
155155
};
156156

157+
class MembarAnalysis : public MembarOrFenceAnalysis {
158+
public:
159+
MembarAnalysis() = default;
160+
explicit MembarAnalysis(Allocation *allocation, MembarFilterFn filter)
161+
: MembarOrFenceAnalysis(allocation, filter) {}
162+
163+
private:
164+
/// Updates the BlockInfo operation based on the operation.
165+
virtual void update(Operation *operation, BlockInfo *blockInfo,
166+
FuncBlockInfoMapT *funcBlockInfoMap,
167+
OpBuilder *builder) override;
168+
169+
void insertBarrier(Operation *operation, OpBuilder *builder);
170+
};
171+
157172
/// Postorder traversal on the callgraph to insert membar instructions
158173
/// of each function.
159174
/// Each function maintains a BlockInfo map that includes all potential buffers
160175
/// after returning. This way users do not have to explicitly insert membars
161176
/// before and after function calls, but might be a bit conservative.
162-
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
177+
template <typename AnalysisType>
178+
class ModuleMembarOrFenceAnalysis : public CallGraph<BlockInfo> {
163179
public:
164-
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation,
165-
MembarFilterFn filter = nullptr)
180+
ModuleMembarOrFenceAnalysis(ModuleAllocation *moduleAllocation,
181+
MembarFilterFn filter = nullptr)
166182
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
167183
moduleAllocation(moduleAllocation), filter(filter) {}
168184

@@ -175,7 +191,7 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
175191
auto *allocation = moduleAllocation->getFuncData(funcOp);
176192
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
177193
if (inserted) {
178-
MembarAnalysis analysis(allocation, filter);
194+
AnalysisType analysis(allocation, filter);
179195
analysis.run(funcMap);
180196
}
181197
});
@@ -186,6 +202,8 @@ class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
186202
MembarFilterFn filter;
187203
};
188204

205+
typedef ModuleMembarOrFenceAnalysis<MembarAnalysis> ModuleMembarAnalysis;
206+
189207
} // namespace mlir
190208

191209
#endif // TRITON_ANALYSIS_MEMBAR_H

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ void replaceUsesWithLocalLoad(
272272
// after converting loads into async loads.
273273
bool comesFromLoadOrBlockArg(Value v);
274274

275+
// For structured control flow ops, returns the values associated with the
276+
// `resultIdx`th result.
277+
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx);
278+
275279
} // namespace mlir::triton
276280

277281
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp">
4141
}
4242

4343
def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::ModuleOp"> {
44+
let summary = "Insert fences across generic and async proxy.";
45+
46+
let description = [{
47+
This pass is to insert memory fences to ensure that memory operations are
48+
properly ordered across generic and async operations.
49+
This pass inserts fences at optimized locations.
50+
There is a pass later to handle all the functional requirements.
51+
}];
52+
53+
let dependentDialects = [
54+
"mlir::triton::gpu::TritonGPUDialect",
55+
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
56+
];
57+
58+
let options = [
59+
Option<"computeCapability", "compute-capability",
60+
"int32_t", /*default*/"90",
61+
"device compute capability">
62+
];
63+
}
64+
65+
def TritonGPUProxyFenceInsertion : Pass<"triton-nvidia-gpu-proxy-fence-insertion", "mlir::ModuleOp"> {
4466
let summary = "Insert fences across generic and async proxy";
4567

4668
let description = [{

lib/Analysis/Membar.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
namespace mlir {
1010

11-
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
11+
void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
1212
FunctionOpInterface funcOp =
1313
dyn_cast<FunctionOpInterface>(allocation->getOperation());
1414
OpBuilder builder(funcOp.getContext());
1515
resolve(funcOp, &funcBlockInfoMap, &builder);
1616
}
1717

18-
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
19-
FuncBlockInfoMapT *funcBlockInfoMap,
20-
OpBuilder *builder) {
18+
void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
19+
FuncBlockInfoMapT *funcBlockInfoMap,
20+
OpBuilder *builder) {
2121
// Initialize the blockList. Operations are organized into "virtual blocks",
2222
// which represent segments of straight-line code analyzed by each iteration
2323
// of the dataflow analysis. Virtual blocks abstract over both control flow
@@ -103,8 +103,8 @@ void MembarAnalysis::resolve(FunctionOpInterface funcOp,
103103
});
104104
}
105105

106-
void MembarAnalysis::visitTerminator(Operation *op,
107-
SmallVector<VirtualBlock> &successors) {
106+
void MembarOrFenceAnalysis::visitTerminator(
107+
Operation *op, SmallVector<VirtualBlock> &successors) {
108108
if (isa<BranchOpInterface>(op)) {
109109
// Collect the block successors of the branch.
110110
for (Block *successor : op->getSuccessors())

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ namespace gpu {
3333
#define GEN_PASS_DEF_TRITONGPUPIPELINE
3434
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
3535

36-
static void pipelineWgmma(ModuleOp moduleOp) {
36+
static void pipelineWgmma(ModuleOp moduleOp, unsigned numStages) {
3737
SmallVector<scf::ForOp> loops;
3838
moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
3939

4040
for (scf::ForOp forOp : loops) {
41-
mlir::triton::asyncLaunchDots(forOp);
41+
if (getNumStagesOrDefault(forOp, numStages) >= 1)
42+
mlir::triton::asyncLaunchDots(forOp);
4243
}
4344
}
4445

@@ -223,7 +224,6 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
223224

224225
void runOnOperation() override {
225226
ModuleOp moduleOp = getOperation();
226-
227227
// Transform the loop by introducing async operations to prepare it for
228228
// pipeline expansion.
229229
lowerLoops(moduleOp);
@@ -244,7 +244,7 @@ struct PipelinePass : public impl::TritonGPUPipelineBase<PipelinePass> {
244244
// Cleanup the IR from the pipeline attributes.
245245
removeAttributes(moduleOp);
246246

247-
pipelineWgmma(moduleOp);
247+
pipelineWgmma(moduleOp, numStages);
248248

249249
// schedule the waits
250250
mlir::triton::updateWaits(getOperation());

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,7 @@ void LayoutRematerialization::backwardRematerialization(
11931193
} else if (isa<arith::ConstantOp>(op)) {
11941194
// special-case: arith.constant has zero cost
11951195
continue;
1196-
} else if (isa<LoadOp>(op)) {
1196+
} else if (isa<LoadOp>(op) || isa<LocalLoadOp>(op)) {
11971197
// optimistically assume L1-cached:
11981198
for (Value result : op->getResults()) {
11991199
rematerialisationCost += 8 * getByteCount(result);
@@ -1208,6 +1208,12 @@ void LayoutRematerialization::backwardRematerialization(
12081208
for (Value result : op->getResults()) {
12091209
rematerialisationCost += multiplier * getByteCount(result);
12101210
}
1211+
} else if (isa<ReduceOp>(op)) {
1212+
// Reduce ops introduce significant rematerialization cost.
1213+
auto reduceOp = dyn_cast<ReduceOp>(op);
1214+
ReduceOpHelper helper(reduceOp);
1215+
rematerialisationCost += helper.getIntraWarpSizeWithUniqueData();
1216+
rematerialisationCost += 8 * helper.getInterWarpSizeWithUniqueData();
12111217
}
12121218
}
12131219

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,4 +1586,36 @@ bool comesFromLoadOrBlockArg(Value v) {
15861586
isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp>(v.getDefiningOp()));
15871587
}
15881588

1589+
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
1590+
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1591+
auto iterArg = forOp.getRegionIterArg(resultIdx);
1592+
auto result = forOp.getResult(resultIdx);
1593+
auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx);
1594+
auto initVal = forOp.getInitArgs()[resultIdx];
1595+
return {iterArg, result, yieldVal, initVal};
1596+
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1597+
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
1598+
auto result = whileOp.getResults()[resultIdx];
1599+
auto yieldVal =
1600+
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
1601+
auto initVal = whileOp.getOperands()[resultIdx];
1602+
return {iterArg, result, iterArg, initVal};
1603+
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1604+
SmallVector<Value> values;
1605+
for (auto &block : ifOp.getThenRegion().getBlocks()) {
1606+
auto terminator = block.getTerminator();
1607+
if (isa<scf::YieldOp>(terminator))
1608+
values.push_back(terminator->getOperands()[resultIdx]);
1609+
}
1610+
for (auto &block : ifOp.getElseRegion().getBlocks()) {
1611+
auto terminator = block.getTerminator();
1612+
if (isa<scf::YieldOp>(terminator))
1613+
values.push_back(terminator->getOperands()[resultIdx]);
1614+
}
1615+
values.push_back(ifOp->getResults()[resultIdx]);
1616+
return values;
1617+
}
1618+
return {};
1619+
}
1620+
15891621
} // namespace mlir::triton

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_triton_library(TritonNvidiaGPUTransforms
66
OptimizeTMemLayouts.cpp
77
PlanCTA.cpp
88
PromoteLHSToTMem.cpp
9+
ProxFenceInsertion.cpp
910
RemoveTMEMTokens.cpp
1011
TensorMemoryAllocation.cpp
1112
TMALowering.cpp

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -128,38 +128,6 @@ namespace nvidia_gpu {
128128

129129
namespace {
130130

131-
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
132-
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
133-
auto iterArg = forOp.getRegionIterArg(resultIdx);
134-
auto result = forOp.getResult(resultIdx);
135-
auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx);
136-
auto initVal = forOp.getInitArgs()[resultIdx];
137-
return {iterArg, result, yieldVal, initVal};
138-
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
139-
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
140-
auto result = whileOp.getResults()[resultIdx];
141-
auto yieldVal =
142-
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
143-
auto initVal = whileOp.getOperands()[resultIdx];
144-
return {iterArg, result, iterArg, initVal};
145-
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
146-
SmallVector<Value> values;
147-
for (auto &block : ifOp.getThenRegion().getBlocks()) {
148-
auto terminator = block.getTerminator();
149-
if (isa<scf::YieldOp>(terminator))
150-
values.push_back(terminator->getOperands()[resultIdx]);
151-
}
152-
for (auto &block : ifOp.getElseRegion().getBlocks()) {
153-
auto terminator = block.getTerminator();
154-
if (isa<scf::YieldOp>(terminator))
155-
values.push_back(terminator->getOperands()[resultIdx]);
156-
}
157-
values.push_back(ifOp->getResults()[resultIdx]);
158-
return values;
159-
}
160-
return {};
161-
}
162-
163131
const EncodingInfo *internEncoding(std::unordered_set<EncodingInfo> &encodings,
164132
EncodingInfo info) {
165133
return &*encodings.insert(info).first;

0 commit comments

Comments
 (0)