
Commit 1572ee6

[Pipelining] Fix TMA descriptor pipelining when only the TMA store is pipelined (#7141)
When only the TMA store is pipelined, we need to make sure we double-buffer the TMA descriptor. We also need to make sure the store wait op waits until the store is done, not only until the smem is read.
1 parent 9695bae commit 1572ee6
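For intuition, the descriptor double buffering amounts to rotating through a small ring of descriptor slots with a wrap-around counter, so a still-in-flight TMA store can keep reading one slot while the next iteration writes another. A minimal, self-contained C++ sketch of that counter arithmetic (it mirrors the createIncrementModulo helper this commit promotes to a shared utility below; SlotCounter and its names are illustrative, not Triton APIs):

    #include <cstdint>

    // Ring-buffer index arithmetic behind the descriptor double buffering:
    // advance the slot index by one and wrap back to zero at the buffer count.
    struct SlotCounter {
      int32_t counter = 0;

      // Returns the next slot index modulo numBuffers; optionally reports
      // whether the index wrapped (the analogue of createIncrementModulo's
      // outWrapCond output parameter).
      int32_t next(int32_t numBuffers, bool *wrapped = nullptr) {
        int32_t addOne = counter + 1;
        bool outOfRange = addOne >= numBuffers;
        if (wrapped)
          *wrapped = outOfRange;
        counter = outOfRange ? 0 : addOne;
        return counter;
      }
    };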

File tree

15 files changed: +365 -265 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ dev-install-llvm:
 
 .PHONY: golden-samples
 golden-samples: triton-opt
-	$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
+	$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
 	$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
 		-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
 	$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 4 deletions
@@ -423,10 +423,6 @@ size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
 Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
                         StringRef content);
 
-inline bool isKernel(FunctionOpInterface funcOp) {
-  return funcOp.getVisibility() == SymbolTable::Visibility::Public;
-}
-
 Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
 
 Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
@@ -182,6 +182,10 @@ Value getLastInductionValue(OpBuilder &b, scf::ForOp loop);
 
 MakeTensorPtrOp getMakeTensorPtrOp(Value v);
 
+bool isHostSideDescriptor(Value v);
+
+bool isKernel(FunctionOpInterface funcOp);
+
 } // namespace triton
 } // namespace mlir

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
+class CoarseSchedule;
 
 //===----------------------------------------------------------------------===//
 // Hoisting Utilities
@@ -138,6 +139,12 @@ createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
 TypedValue<triton::gpu::MemDescType>
 createSingleBufferView(OpBuilder &builder, Value alloc, int idx);
 
+Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,
+                            Value modulus, Value zero, Value one,
+                            Value *outWrapCond = nullptr);
+
+scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule);
+
 } // namespace triton
 } // namespace mlir
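The createIncrementModulo declaration above is the ring-buffer step used for the descriptor slots; its body (visible in the LowerLoops.cpp deletion further down) selects between counter + 1 and zero. A hedged usage sketch, where builder, loc, counter, numBuffersVal, zero, and one stand in for i32 Values the calling pass already has in scope:

    // Sketch: advance a per-descriptor slot counter inside a lowering pass.
    // All operands are assumed to be existing Values in the enclosing pass.
    Value wrapCond; // receives the i1 "index wrapped back to zero" condition
    Value nextCounter = triton::createIncrementModulo(
        builder, loc, counter, numBuffersVal, zero, one, &wrapCond);
    // nextCounter == (counter + 1 >= numBuffers) ? 0 : counter + 1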

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 27 additions & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -164,6 +165,32 @@ class CoarseSchedule {
 // the same stage and ordering cluster as the anchor op.
 void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule);
 
+class OpBuilderForStage : public mlir::ImplicitLocOpBuilder,
+                          public OpBuilder::Listener {
+public:
+  explicit OpBuilderForStage(Location loc, Operation *op,
+                             CoarseSchedule &schedule)
+      : ImplicitLocOpBuilder(loc, op, this), schedule(schedule) {
+    if (auto it = schedule.find(op); it != schedule.end())
+      std::tie(stage, cluster) = it->second;
+  }
+
+  void setStageCluster(std::pair<int, CoarseSchedule::Cluster> stageCluster) {
+    stage = stageCluster.first;
+    cluster = stageCluster.second;
+  }
+
+  void notifyOperationInserted(Operation *op, InsertPoint previous) {
+    if (stage && cluster)
+      schedule.insert(op, *stage, *cluster);
+  }
+
+private:
+  std::optional<int> stage;
+  std::optional<CoarseSchedule::Cluster> cluster;
+  CoarseSchedule &schedule;
+};
+
 } // namespace triton
 } // namespace mlir
 #endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
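OpBuilderForStage, now exposed from this header, couples op creation with scheduling: because it installs itself as an OpBuilder::Listener, every op created through it is automatically inserted into the CoarseSchedule at the stage and cluster of the anchor op. A hedged sketch of how a pass uses it (makeDescOp, someOtherOp, and schedule are assumed to come from the surrounding pass, as in the TMA-descriptor lowering below):

    // Anchor the builder on makeDescOp; new ops inherit its stage/cluster.
    OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, schedule);
    Value zero = builder.create<arith::ConstantIntOp>(0, 32); // auto-scheduled

    // Retarget later ops to another op's stage/cluster when needed.
    if (auto it = schedule.find(someOtherOp); it != schedule.end())
      builder.setStageCluster(it->second);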

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
     // 1. Modify the function type to add the new arguments.
     auto funcTy = funcOp.getFunctionType();
     auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
-    bool isKernel = LLVM::isKernel(funcOp);
+    bool isKernel = triton::isKernel(funcOp);
     if (isKernel) {
       for (auto i : llvm::seq(amendedInputTy.size())) {
         if (isa<TensorDescType>(amendedInputTy[i])) {
@@ -111,7 +111,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
   // Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM
   // attributes.
   static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) {
-    const bool isKernel = LLVM::isKernel(llvmFuncOp);
+    const bool isKernel = triton::isKernel(llvmFuncOp);
     for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) {
       const auto attrs = llvmFuncOp.getArgAttrDict(i);
       if (!attrs) {
@@ -161,7 +161,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
 
     auto ctx = funcOp->getContext();
 
-    if (LLVM::isKernel(funcOp)) {
+    if (triton::isKernel(funcOp)) {
       // Set an attribute to indicate this function is a kernel entry.
       newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
                          rewriter.getIntegerAttr(type::u1Ty(ctx), 1));

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 14 additions & 0 deletions
@@ -103,3 +103,17 @@ Value tt::getLastInductionValue(OpBuilder &b, scf::ForOp loop) {
       loc, b.create<arith::DivSIOp>(loc, diff, loop.getStep()), loop.getStep());
   return b.create<arith::AddIOp>(loc, ceilStep, loop.getLowerBound());
 }
+
+bool tt::isKernel(FunctionOpInterface funcOp) {
+  return funcOp.getVisibility() == SymbolTable::Visibility::Public;
+}
+
+bool tt::isHostSideDescriptor(Value v) {
+  auto arg = dyn_cast<BlockArgument>(v);
+  if (!arg)
+    return false;
+  auto funcOp = dyn_cast<FunctionOpInterface>(arg.getOwner()->getParentOp());
+  if (!funcOp)
+    return false;
+  return tt::isKernel(funcOp);
+}
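The new isHostSideDescriptor predicate classifies where a TMA descriptor came from: per the implementation above, it returns true exactly when the value is a block argument of a public (kernel entry) function, i.e. a descriptor built on the host and passed in as a kernel parameter, and false for descriptors materialized inside the kernel. A hedged sketch of the distinction (desc is a hypothetical Value):

    // desc: some Value feeding a descriptor load/store op.
    if (triton::isHostSideDescriptor(desc)) {
      // Kernel argument: the descriptor bytes were written by the host, so a
      // pass would typically leave it untouched.
    } else {
      // Defined in the kernel (e.g. by tt.make_tensor_descriptor) or inside a
      // non-kernel function; a candidate for the double-buffered update.
    }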

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 0 additions & 164 deletions
@@ -1,6 +1,5 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -38,32 +37,6 @@ namespace {
 // UTILS
 /////////////////////////////
 
-class OpBuilderForStage : public ImplicitLocOpBuilder,
-                          public OpBuilder::Listener {
-public:
-  explicit OpBuilderForStage(Location loc, Operation *op,
-                             CoarseSchedule &schedule)
-      : ImplicitLocOpBuilder(loc, op, this), schedule(schedule) {
-    if (auto it = schedule.find(op); it != schedule.end())
-      std::tie(stage, cluster) = it->second;
-  }
-
-  void setStageCluster(std::pair<int, CoarseSchedule::Cluster> stageCluster) {
-    stage = stageCluster.first;
-    cluster = stageCluster.second;
-  }
-
-  void notifyOperationInserted(Operation *op, InsertPoint previous) {
-    if (stage && cluster)
-      schedule.insert(op, *stage, *cluster);
-  }
-
-private:
-  std::optional<int> stage;
-  std::optional<CoarseSchedule::Cluster> cluster;
-  CoarseSchedule &schedule;
-};
-
 int getSelfLatencyFromAttr(Operation *op) {
   auto module = op->getParentOfType<ModuleOp>();
   auto helper = TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper();
@@ -207,17 +180,6 @@ int getDefUseStageDiff(Operation *op, scf::ForOp forOp,
   return useStage.value() - defStage;
 }
 
-Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,
-                            Value modulus, Value zero, Value one,
-                            Value *outWrapCond = nullptr) {
-  Value addOne = builder.create<arith::AddIOp>(loc, counter, one);
-  Value outOfRangeCond = builder.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::sge, addOne, modulus);
-  if (outWrapCond)
-    *outWrapCond = outOfRangeCond;
-  return builder.create<arith::SelectOp>(loc, outOfRangeCond, zero, addOne);
-}
-
 void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue,
                                DominanceInfo &domInfo) {
   if (newValue == oldValue)
@@ -644,132 +606,6 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
   return forOp;
 }
 
-/////////////////////////////
-// LOWER TMA DESCRIPTORS
-/////////////////////////////
-
-void allocTMABuffers(scf::ForOp forOp,
-                     llvm::MapVector<Operation *, Value> &tmaBufferMapping,
-                     int maxStage) {
-  IRRewriter rewriter(forOp);
-
-  // Create a multi-buffered allocation for each MakeTensorDescOp call in the
-  // loop
-  forOp.walk([&](tt::MakeTensorDescOp op) {
-    // TODO peter: walk to loop yield to find the init value if this is a
-    // loop-carried value. That would save us from allocating another buffer
-    // just for the init value
-    auto loc = op.getLoc();
-    Value alloc = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
-        loc, triton::getPointerType(rewriter.getI8Type()),
-        maxStage * ttng::TMA_SIZE_BYTES, ttng::TMA_ALIGN);
-    tmaBufferMapping[op.getOperation()] = alloc;
-  });
-}
-
-Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc,
-                           Value counter) {
-  Value tmaSizeVal =
-      builder.create<arith::ConstantIntOp>(loc, ttng::TMA_SIZE_BYTES, 32);
-  Value offset = builder.create<arith::MulIOp>(loc, tmaSizeVal, counter);
-  return builder.create<triton::AddPtrOp>(loc, alloc.getType(), alloc, offset);
-}
-
-LogicalResult rewriteTMABufferUpdates(
-    scf::ForOp forOp,
-    const llvm::MapVector<Operation *, Value> &tmaBufferMapping,
-    ArrayRef<BlockArgument> tmaCounters, int numBuffers, Value one, Value zero,
-    CoarseSchedule &schedule) {
-  assert(tmaBufferMapping.size() == tmaCounters.size());
-
-  Value numBuffersVal = mlir::OpBuilder(forOp).create<arith::ConstantIntOp>(
-      forOp.getLoc(), numBuffers, 32);
-
-  for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) {
-    auto &[op, alloc] = pair;
-
-    // Rewriter MakeTensorDescOp as writing a TMA descriptor
-    auto makeDescOp = cast<tt::MakeTensorDescOp>(op);
-
-    OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp, schedule);
-
-    BlockArgument counter = tmaCounters[iOp];
-    Value nextBuf =
-        subviewTMADescriptor(builder, builder.getLoc(), alloc, counter);
-    if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) {
-      return failure();
-    }
-    builder.create<ttng::TensormapFenceproxyAcquireOp>(nextBuf);
-    Value nextDesc = builder.create<ttng::ReinterpretTensorDescOp>(
-        makeDescOp.getType(), nextBuf);
-
-    makeDescOp.getResult().replaceAllUsesWith(nextDesc);
-
-    // Increment the buffer index counter
-    Value nextCounter = createIncrementModulo(
-        builder, builder.getLoc(), counter, numBuffersVal, zero, one);
-
-    // If we are in a (potentially nested) if region, propagate the counter
-    // up to the main for op body scope
-    IRRewriter rewriter(forOp);
-    nextCounter =
-        sinkValueRedefinition(rewriter, counter, nextCounter, op->getBlock());
-
-    // Finally, rewrite the loop level yield
-    auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-    forYield.setOperand(counter.getArgNumber() - 1, nextCounter);
-  }
-  return success();
-}
-
-scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule) {
-  llvm::MapVector<Operation *, Value> tmaBufferMapping;
-  int maxStage = schedule.getNumStages() - 1;
-  for (auto &op : forOp.getBody()->without_terminator()) {
-    if (auto wgMmaOp = dyn_cast<ttng::WarpGroupDotOp>(&op)) {
-      // Hopper only: Add one more buffer slice if there is a WarpGroupDotOp,
-      // as if it will be pipelined, we will effectively make the pipeline
-      // one stage longer.
-      maxStage += 1;
-      break;
-    }
-  }
-  allocTMABuffers(forOp, tmaBufferMapping, maxStage);
-  if (tmaBufferMapping.empty())
-    return forOp;
-
-  IRRewriter builder(forOp);
-  Location loc = forOp.getLoc();
-  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-  Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
-  SmallVector<Value> newOperands;
-  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
-  // Create one counter per TMA buffer. This allows the descriptors to be
-  // updated independently without needing to write duplicate of existing tma
-  // descriptors.
-  unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size();
-  for (int i = 0; i < tmaBufferMapping.size(); ++i) {
-    newOperands.push_back(zero);
-  }
-
-  forOp = addIterArgsToLoop(builder, forOp, newOperands);
-
-  auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
-                         .slice(tmaCounterArgsStartIdx);
-
-  // Update yield op with temporary yield values
-  auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-  for (unsigned i = 0; i < newOperands.size(); ++i) {
-    forYield.getResultsMutable().append(newOperands[i]);
-  }
-
-  if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters,
-                                     maxStage, one, zero, schedule))) {
-    llvm_unreachable("Failed to rewrite TMA ops");
-  }
-  return forOp;
-}
-
 /////////////////////////////
 // LOWER MMA
 /////////////////////////////
