Commit 9f918cd

Merge OpenAI Triton commit 75eed88 (#4809)
This PR changes the Triton base from 167ed28 to 75eed88 (Jul 22). Pass rate: 98.62%.
2 parents (9f1ebab + a3ddde6), commit 9f918cd

File tree: 34 files changed (+1031 / −1229 lines)

34 files changed

+1031
-1229
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 2 deletions
@@ -221,13 +221,19 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
 // [FIXME LL] Kill this function
 SmallVector<unsigned> getShapePerCTATile(RankedTensorType layout);
 
-// Returns the "logical" shape per CTA
+// Returns the "logical" shape per CTA.
+// When shape and CTASplitNum have different number of dimensions, we assume
+// only the last N between common dimensions are split.
+// Example1: shape = [2, 4, 8], CTASplitNum = [2, 2], ret = [2, 2, 4].
+// It can be caused by pipelining.
+// Example2: shape = [2, 4], CTASplitNum = [2, 2, 2], ret = [1, 2].
+// It can be caused by memory slicing.
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Type type);
 
-// Returns the shape per CTA, which is "physically" allocated
+// Returns the shape per CTA, which is "physically" allocated.
 // Such shapes may be bigger than the logical one due to, for example, padding
 // in shared memory.
 SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
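
The new comment fixes the semantics of the ArrayRef overload when the ranks of shape and CTASplitNum differ. As a minimal illustration (not part of the commit, and assuming the Triton headers are on the include path), the two documented cases look like this at a call site:

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

void shapePerCTAExamples() {
  using namespace mlir::triton::gpu;
  // Example1 (pipelining): CTASplitNum is shorter than shape, so the extra
  // leading dimension is left unsplit.
  llvm::SmallVector<int64_t> a =
      getShapePerCTA(/*CTASplitNum=*/{2, 2}, /*shape=*/{2, 4, 8}); // {2, 2, 4}
  // Example2 (memory slicing): CTASplitNum is longer than shape, so its extra
  // leading entries are ignored.
  llvm::SmallVector<int64_t> b =
      getShapePerCTA(/*CTASplitNum=*/{2, 2, 2}, /*shape=*/{2, 4}); // {1, 2}
  (void)a;
  (void)b;
}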

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 2 additions & 1 deletion
@@ -130,7 +130,8 @@ def TritonGPURewritePartitionDependencies : Pass<"tritongpu-rewrite-partition-de
     "mlir::triton::gpu::TritonGPUDialect",
     "mlir::scf::SCFDialect",
     "mlir::arith::ArithDialect",
-    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
+    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+    "mlir::triton::nvws::NVWSDialect"
   ];
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 21 deletions
@@ -293,34 +293,22 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
+  auto splitNum = llvm::to_vector(CTASplitNum);
+  if (splitNum.size() <= rank) { // pipelining
+    splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1);
+  } else { // memory slicing
+    splitNum =
+        llvm::to_vector(llvm::drop_begin(splitNum, splitNum.size() - rank));
+  }
   SmallVector<int64_t> shapePerCTA(rank);
   for (unsigned i = 0; i < rank; ++i) {
-    unsigned splitNum = std::min<unsigned>(shape[i], CTASplitNum[i]);
-    shapePerCTA[i] = shape[i] / splitNum;
+    shapePerCTA[i] = shape[i] / std::min<unsigned>(shape[i], splitNum[i]);
   }
   return shapePerCTA;
 }
 
 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape) {
-  if (mlir::isa<SharedEncodingTrait>(layout)) {
-    // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D.
-    // The first dim of shape is numStages. This is a work around, otherwise
-    // too many places would have to be modified in pipeline pass. Maybe we
-    // need to refactor this logic in the future.
-    auto CTASplitNum = cast<LayoutEncodingTrait>(layout).getCTASplitNum();
-    if (shape.size() == CTASplitNum.size() + 1) {
-      auto res = getShapePerCTA(CTASplitNum, shape.drop_front());
-      res.insert(res.begin(), shape.front());
-      return res;
-    }
-  }
-  SmallVector<unsigned> splitNum = getCTASplitNum(layout);
-  if (auto tmem = dyn_cast<nvidia_gpu::TensorMemoryEncodingAttr>(layout)) {
-    if (shape.size() > splitNum.size()) {
-      splitNum.insert(splitNum.begin(), shape.size() - splitNum.size(), 1);
-    }
-  }
-  return getShapePerCTA(splitNum, shape);
+  return getShapePerCTA(getCTASplitNum(layout), shape);
 }
 
 SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
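
For readers who want to check the new rank-normalization rule without building Triton, the following standalone sketch (added for illustration here, not code from the commit) mirrors the logic above with plain STL containers and asserts the two examples documented in Dialect.h:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors getShapePerCTA's normalization: pad splitNum with leading 1s when it
// is shorter than shape (pipelining), or drop its leading entries when it is
// longer (memory slicing), then divide each dimension by its split factor.
static std::vector<int64_t> shapePerCTASketch(std::vector<unsigned> splitNum,
                                              const std::vector<int64_t> &shape) {
  size_t rank = shape.size();
  if (splitNum.size() <= rank)
    splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1u);
  else
    splitNum.erase(splitNum.begin(),
                   splitNum.begin() + (splitNum.size() - rank));
  std::vector<int64_t> result(rank);
  for (size_t i = 0; i < rank; ++i)
    result[i] = shape[i] / std::min<int64_t>(shape[i], splitNum[i]);
  return result;
}

int main() {
  // Example1 from Dialect.h: pipelining.
  assert((shapePerCTASketch({2, 2}, {2, 4, 8}) ==
          std::vector<int64_t>{2, 2, 4}));
  // Example2 from Dialect.h: memory slicing.
  assert((shapePerCTASketch({2, 2, 2}, {2, 4}) == std::vector<int64_t>{1, 2}));
  return 0;
}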

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 1 addition & 107 deletions
@@ -148,94 +148,6 @@ class SinkTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
   }
 };
 
-// Combine back TMEM alloc and store. This is equivalent but gives us a more
-// canonical form to do further optimizations.
-class CombineTMEMStoreAndAlloc : public OpRewritePattern<TMEMTokenStoreOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
-                                PatternRewriter &rewriter) const override {
-    if (!matchPattern(store.getPred(), m_One()))
-      return failure();
-    auto alloc = store.getDep().getDefiningOp<TMEMTokenAllocOp>();
-    if (!alloc)
-      return failure();
-    if (alloc->getBlock() != store->getBlock())
-      return failure();
-    alloc.getSrcMutable().assign(store.getSrc());
-    rewriter.replaceOp(store, alloc.getToken());
-    return success();
-  }
-};
-
-// Hoists a tmem alloc outside an if op like this:
-// %0 = scf.if {
-//   %1, %token0 = tmem.alloc %init
-//   ...
-//   %2 = tmem.load %1, %token1
-//   scf.yield %2
-// } else {
-//   scf.yield %init
-// }
-// ->
-// %a, %token0 = tmem.alloc %init
-// %token2 = scf.if {
-//
-//   ...
-//   scf.yield %token1
-// } else {
-//   scf.yield %token0
-// }
-// %2 = tmem.load %a, %token2
-class HoistTMEMAllocOutOfIf : public OpRewritePattern<ttng::TMEMAllocOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ttng::TMEMAllocOp alloc,
-                                PatternRewriter &rewriter) const override {
-    if (!alloc.getToken())
-      return failure();
-    Value init = alloc.getSrc();
-    if (!init)
-      return failure();
-    auto ifOp = dyn_cast<scf::IfOp>(alloc->getParentOp());
-    if (!ifOp)
-      return failure();
-    auto thenOp = ifOp.thenBlock()->getTerminator();
-    auto elseOp = ifOp.elseBlock()->getTerminator();
-    SmallVector<int> yieldArgs;
-    for (auto [thenOperand, elseOperand] :
-         llvm::zip(thenOp->getOpOperands(), elseOp->getOpOperands())) {
-      auto load = thenOperand.get().getDefiningOp<TMEMTokenLoadOp>();
-      if (!load || load.getSrc() != alloc.getResult())
-        continue;
-      if (elseOperand.get() != init)
-        continue;
-      yieldArgs.push_back(thenOperand.getOperandNumber());
-    }
-    if (yieldArgs.empty())
-      return failure();
-    // Since init is used in the else terminator we know that it dominates the
-    // if op.
-    alloc->moveBefore(ifOp);
-    rewriter.setInsertionPointAfter(ifOp);
-    for (int argNo : yieldArgs) {
-      auto load =
-          cast<TMEMTokenLoadOp>(thenOp->getOperand(argNo).getDefiningOp());
-      auto newLoad = cast<TMEMTokenLoadOp>(rewriter.clone(*load));
-      rewriter.modifyOpInPlace(ifOp, [&] {
-        ifOp->getResult(argNo).replaceAllUsesWith(newLoad.getResult());
-        newLoad.getDepMutable().assign(ifOp->getResult(argNo));
-        thenOp->setOperand(argNo, load.getToken());
-        elseOp->setOperand(argNo, alloc.getToken());
-        ifOp->getResult(argNo).setType(newLoad.getToken().getType());
-      });
-    }
-    return success();
-  }
-};
-
 // Remove loop-carried tensor dependencies if they are fed immediately into a
 // TMEM store by pulling the store into the previous iteration.
 class RotateTMEMStoreInLoop : public OpRewritePattern<TMEMTokenStoreOp> {

@@ -500,29 +412,11 @@ struct HoistTMEMAlloc
     mlir::RewritePatternSet patterns(&getContext());
     patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
                  CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
-                 SinkTMEMLoad, RemoveUnusedTMEMLoad, CombineTMEMStoreAndAlloc,
-                 HoistTMEMAllocOutOfIf>(&getContext());
+                 SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
     scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       llvm_unreachable("Failed to hoist tmem_store");
     }
-
-    // TODO: currently some code assumes that a mutable tmem alloc doesn't have
-    // an initial value. As a workaround we break up the op in order to keep
-    // this form for the downstream passes. We should remove this once the
-    // downstread passes are fixed.
-    m.walk([&](ttng::TMEMAllocOp alloc) {
-      if (alloc.getType().getMutableMemory() && alloc.getSrc()) {
-        OpBuilder builder(alloc);
-        builder.setInsertionPointAfter(alloc);
-        auto store = builder.create<ttng::TMEMStoreOp>(
-            alloc.getLoc(), builder.getType<AsyncTokenType>(),
-            alloc.getResult(), alloc.getToken(), alloc.getSrc(),
-            builder.create<arith::ConstantIntOp>(alloc.getLoc(), 1, 1));
-        alloc.getToken().replaceAllUsesExcept(store.getToken(), store);
-        alloc.getSrcMutable().clear();
-      }
-    });
   }
 };

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
+#include "third_party/nvidia/include/Dialect/NVWS/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
@@ -42,6 +43,8 @@ void AutomaticWarpSpecialization::runOnOperation() {
   pm.addPass(createSCCPPass());
   pm.addPass(createCSEPass());
   pm.addPass(createTritonGPUPartitionLoops());
+  pm.addPass(createNVWSLowerAref());
+  pm.addPass(createNVWSLowerWarpGroup());
   if (failed(runPipeline(pm, getOperation())))
     return signalPassFailure();

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 0 additions & 51 deletions
@@ -191,57 +191,6 @@ LogicalResult WarpSchedule::verify(scf::ForOp loop) const {
   if (failed)
     return failure();
 
-  // Within a loop iteration, the partitions must form a DAG. For example, the
-  // following is invalid:
-  //
-  //   scf.for %i = %lb to %ub step %step
-  //     %0 = op_a() {ttg.partition = 0}
-  //     %1 = op_b(%0) {ttg.partition = 1}
-  //     op_c(%1) {ttg.partition = 0}
-  //
-  PartitionGraph graph(loop, *this);
-  for (auto it = llvm::scc_begin(graph); !it.isAtEnd(); ++it) {
-    if (!it.hasCycle())
-      continue;
-    InFlightDiagnostic diag =
-        mlir::emitWarning(loop.getLoc(), "warp schedule contains a cycle");
-    for (auto [node, use] : *it) {
-      assert(use && "already checked that the root partition has no ancestors");
-      diag.attachNote(use->getOwner()->getLoc())
-          << "operation in partition #" << node->partition->getIndex()
-          << " uses value defined in partition #"
-          << opToPartition.at(use->get().getDefiningOp())->getIndex();
-    }
-    return failure();
-  }
-
-  // Each partition's stage must be strictly less than all of its consumers plus
-  // the distance.
-  for (Partition &partition : getPartitions()) {
-    bool failed = false;
-    auto callback = [&](OpResult output, OpOperand &use, unsigned distance) {
-      Operation *user = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
-      const Partition *consumer = opToPartition.at(user);
-      if (partition.getStage() < consumer->getStage() + distance)
-        return;
-      InFlightDiagnostic diag =
-          mlir::emitWarning(loop.getLoc(), "partition #")
-          << partition.getIndex() << " has stage " << partition.getStage()
-          << " but is consumed by partition #" << consumer->getIndex()
-          << " with stage " << consumer->getStage() << " at distance "
-          << distance;
-      diag.attachNote(use.getOwner()->getLoc())
-          << "use of value defined in partition #" << partition.getIndex()
-          << " at " << distance << " iterations in the future";
-      diag.attachNote(output.getLoc())
-          << "value defined here in partition #" << partition.getIndex();
-      failed = true;
-    };
-    iterateUses(loop, &partition, callback);
-    if (failed)
-      return failure();
-  }
-
   return success();
 }

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 0 additions & 6 deletions
@@ -499,10 +499,4 @@ void PartitionLoops::runOnOperation() {
     if (failed(partitionLoop(loop)))
      return signalPassFailure();
   }
-
-  OpPassManager pm;
-  pm.addPass(mlir::triton::createNVWSLowerWarpGroup());
-
-  if (failed(runPipeline(pm, getOperation())))
-    return signalPassFailure();
 }
