Commit 635435f

[Pipeliner] Add support for pipelining loads with different latencies (#5460)
@pawelszczerbuk wrote the code. I just fixed a few things and added a test :)

This generalizes the loop pipeliner infrastructure a bit to support loads with different latencies that are pipelined and multibuffered differently, allowing more fine-grained buffer allocation. The feature isn't exposed yet, but the PR also adds an attribute to the TMA load op allowing the user to manually specify the desired latency.

---------

Co-authored-by: Pawel Szczerbuk <[email protected]>
1 parent 48468af commit 635435f
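The attribute is read back by AssignLatencies.cpp below as an integer "tt_latency" attribute on ops inside the loop body. As a rough illustration only (this helper is not part of the commit and its name is made up), attaching such a hint from C++ could look like:

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Hypothetical helper: tag a load inside an scf.for body with a desired
// latency. AssignLatencies reads the attribute back via getAttr("tt_latency")
// and casts it to IntegerAttr, so any integer attribute works here.
static void setLatencyHint(mlir::Operation *loadOp, int latency) {
  mlir::OpBuilder b(loadOp->getContext());
  loadOp->setAttr("tt_latency", b.getI32IntegerAttr(latency));
}

Loads tagged with different latencies can then be multibuffered differently, which is what the MatmulLoopPipeline changes below implement.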

File tree

3 files changed, +277 -81 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 21 additions & 0 deletions
@@ -183,6 +183,23 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
   return loadOpToIndLevel;
 }
 
+bool hasLatenciesAssigned(scf::ForOp forOp) {
+  for (auto &op : forOp.getBody()->without_terminator()) {
+    if (op.hasAttr("tt_latency"))
+      return true;
+  }
+  return false;
+}
+
+void assignUserProvidedLatencies(scf::ForOp forOp,
+                                 DenseMap<Operation *, int> &opLatency) {
+  for (auto &op : forOp.getBody()->without_terminator()) {
+    if (auto latencyAttr = op.getAttr("tt_latency")) {
+      opLatency[&op] = mlir::cast<IntegerAttr>(latencyAttr).getInt();
+    }
+  }
+}
+
 } // namespace
 
 // Look for load ops that directly or indirectly feed into dot ops. Based
@@ -212,6 +229,10 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
 
   DenseMap<Operation *, int> opLatency;
   for (auto forOp : loops) {
+    if (hasLatenciesAssigned(forOp)) {
+      assignUserProvidedLatencies(forOp, opLatency);
+      continue;
+    }
     int numStages = getNumStagesOrDefault(forOp);
     bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName);
     ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 112 additions & 81 deletions
@@ -121,7 +121,7 @@ static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) {
 static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx,
                            llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-                           int numStages, int maxClusterId) {
+                           int maxClusterId) {
   int retCode = -1;
   OpBuilderWithStage builder(forOp);
   auto opPair = tt::getStageCluster(loadOp);
@@ -234,8 +234,7 @@ static void
 createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
                    Value alloc, Value insertIdx, Value extractIdx,
                    Value barrier, Operation *waitOp, Value phase,
-                   llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
-                   int numStages) {
+                   llvm::MapVector<Operation *, LoadInfo> &loadToInfo) {
   assert(phase && "Phase value is required for TMA async copy.");
   OpBuilderWithStage builder(forOp);
   auto [stage, clusterId] = tt::getStageCluster(loadOp);
@@ -585,21 +584,28 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
   return barrierAlloc;
 }
 
+struct StageGroup {
+  Value insertIdx;
+  Value extractIdx;
+  Value phase;
+  bool hasTMALoad = false;
+};
 struct AsyncLoad {
-  AsyncLoad(Operation *loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {}
   Operation *loadOp;
   Value alloc;
   Value barrier;
   Operation *waitOp = nullptr;
   int firstUseStage, firstUseCluster;
   bool isTMALoad = false;
+  int numBuffers = 0;
 };
 
 // Create barriers and wait ops for the async loads. Barriers may be shared by
-// multiple loads is the schedule allows it.
+// multiple loads if the schedule allows it.
 static void createTMABarrierAndWait(
-    scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads, Value insertIdx,
-    Value extractIdx, Value phase, int numBuffers, SmallVector<Value> &barriers,
+    scf::ForOp &forOp, SmallVector<AsyncLoad> &asyncLoads,
+    SmallVector<Value> &barriers,
+    const llvm::MapVector<int, StageGroup> &stageGroups,
     const llvm::MapVector<Operation *, LoadInfo> &loadToInfo) {
   llvm::SmallDenseMap<Operation *, AsyncLoad *> loadToAsyncLoad;
   for (AsyncLoad &asyncLoad : asyncLoads) {
@@ -639,12 +645,15 @@
     };
     addToGroup(&asyncLoad);
     Operation *nextOp = asyncLoad.loadOp->getNextNode();
+    int numBuffers = asyncLoad.numBuffers;
     while (nextOp) {
      if (users.count(nextOp) || visited.count(nextOp))
        break;
      if (isa<tt::ExperimentalDescriptorLoadOp>(nextOp)) {
        auto it = loadToAsyncLoad.find(nextOp);
        if (it != loadToAsyncLoad.end() && it->second->isTMALoad) {
+          if (it->second->numBuffers != numBuffers)
+            break;
          if (group.size() > 0 &&
              sameStageCluster(group[0]->loadOp, it->second->loadOp))
            addToGroup(it->second);
@@ -659,6 +668,8 @@
   // load.
   for (SmallVector<AsyncLoad *> &group : loadGroups) {
     int sizeInBytes = 0;
+    int numBuffers = group[0]->numBuffers;
+    const StageGroup &stageGroup = stageGroups.find(numBuffers)->second;
     for (AsyncLoad *asyncLoad : group) {
       auto tensorTy =
           cast<RankedTensorType>(asyncLoad->loadOp->getResult(0).getType());
@@ -682,7 +693,7 @@
     builder.setInsertionPoint(group[0]->loadOp);
     Value barrier = builder.createWithStage<ttg::MemDescSubviewOp>(
         loc, stage, cluster, barrierTy, barrierAlloc,
-        ArrayRef<Value>({insertIdx}));
+        ArrayRef<Value>({stageGroup.insertIdx}));
     Value pred = builder.createWithStage<arith::ConstantIntOp>(loc, stage,
                                                                cluster, 1, 1);
     Operation *expect = builder.createWithStage<ttng::BarrierExpectOp>(
@@ -691,10 +702,10 @@
     builder.setInsertionPointAfter(group.back()->loadOp);
     Value barrierViewWait = builder.createWithStage<ttg::MemDescSubviewOp>(
         loc, group[0]->firstUseStage, group[0]->firstUseCluster, barrierTy,
-        barrierAlloc, ArrayRef<Value>({extractIdx}));
+        barrierAlloc, ArrayRef<Value>({stageGroup.extractIdx}));
     Operation *wait = builder.createWithStage<ttng::WaitBarrierOp>(
         loc, group[0]->firstUseStage, group[0]->firstUseCluster,
-        barrierViewWait, phase);
+        barrierViewWait, stageGroup.phase);
     // Update the async loads info.
     for (AsyncLoad *asyncLoad : group) {
       asyncLoad->barrier = barrier;
@@ -855,46 +866,47 @@ static SmallVector<Value>
 createAsyncOps(scf::ForOp &forOp,
                llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
                SmallVector<Value> &barriers, int numStages) {
-  // Calculate the number of buffers needed for each load.
-  // TODO pawel: we could do more fine-grained allocation here and
-  // allocate only the number of buffers that specific loads need.
-  // Instead, we allocate the maximum number of buffers needed by any load.
-  int numBuffers =
-      llvm::max_element(llvm::make_second_range(loadToInfo), [](auto &lhs,
-                                                                auto &rhs) {
-        return lhs.distToUse < rhs.distToUse;
-      })->distToUse;
-  bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) {
-    return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers;
-  });
-  if (hasMMAV3) {
-    // For MMAv3, we need an extra buffer as this is assumed in the wgmma
-    // pipelining post-processing.
-    numBuffers++;
-  };
-
   llvm::MapVector<Operation *, Value> tmaBufferMapping;
   if (failed(allocTMABuffers(forOp, tmaBufferMapping, numStages))) {
     llvm_unreachable("TMA pipelining failed");
   }
 
+  // Each group of loads/allocs with the same number of buffers (and stages)
+  // will share the indices and barriers.
+
   SmallVector<AsyncLoad> asyncLoads;
   SmallVector<Value> allocs;
-  bool hasTMALoad = false;
+  llvm::MapVector<int, StageGroup> stageGroups;
+
   for (auto &[loadOp, info] : loadToInfo) {
+    AsyncLoad asyncLoad = {.loadOp = loadOp};
+    bool isTMALoad = false;
+    int numBuffers = info.distToUse;
+    // For MMAv3, we need an extra buffer as this is assumed in the wgmma
+    // pipelining post-processing.
+    if (info.isMMAv3Shared || info.isMMAv3Registers) {
+      ++numBuffers;
+    }
+    if (isa<tt::ExperimentalDescriptorLoadOp>(loadOp)) {
+      isTMALoad = true;
+      asyncLoad.isTMALoad = isTMALoad;
+    }
     assert(info.sharedEncoding && "LoadOp shared encoding not defined.");
     Value alloc = createAlloc(forOp, loadOp, info.sharedEncoding, numBuffers);
     assert(alloc && "Failed to create alloc for the async load.");
     allocs.push_back(alloc);
-    asyncLoads.emplace_back(loadOp, alloc);
-    if (isa<tt::ExperimentalDescriptorLoadOp>(loadOp)) {
-      hasTMALoad = true;
-      asyncLoads.back().isTMALoad = true;
-    }
+    asyncLoad.alloc = alloc;
+
     auto *firstUse = getFirstUseOfPipelinedLoad(loadOp);
     auto [firstUseStage, firstUseCluster] = tt::getStageCluster(firstUse);
-    asyncLoads.back().firstUseStage = firstUseStage;
-    asyncLoads.back().firstUseCluster = firstUseCluster;
+    asyncLoad.firstUseStage = firstUseStage;
+    asyncLoad.firstUseCluster = firstUseCluster;
+    asyncLoad.numBuffers = numBuffers;
+    stageGroups.insert({numBuffers, {}});
+    if (isTMALoad) {
+      stageGroups[numBuffers].hasTMALoad = true;
+    }
+    asyncLoads.push_back(asyncLoad);
   }
 
   IRRewriter builder(forOp.getContext());
@@ -908,41 +920,34 @@ createAsyncOps(scf::ForOp &forOp,
   Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
   Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
   Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
-  Value insertIdx = minusOne;
-  Value extractIdx = minusOne;
-  Value phase = Value();
-  Value numBuffersVal =
-      builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
   SmallVector<Value> newOperands;
-  newOperands.push_back(insertIdx);
-  newOperands.push_back(extractIdx);
-  if (hasTMALoad) {
-    // A single barrier arrival sequence is a "phase" and two phases can
-    // overlap, provided the phases are differentiated with an alternating
-    // boolean value.
-    phase = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-    newOperands.push_back(phase);
+  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
+  for (auto [_, stageGroup] : stageGroups) {
+    newOperands.push_back(minusOne); // insertIdx
+    newOperands.push_back(minusOne); // extractIdx
+    if (stageGroup.hasTMALoad) {
+      // A single barrier arrival sequence is a "phase" and two phases can
+      // overlap, provided the phases are differentiated with an alternating
+      // boolean value.
+      newOperands.push_back(zero); // phase
+    }
   }
   // Also create one counter per TMA buffer. This allows the descriptors to be
   // updated independently without needing to write duplicate of existing tma
   // descriptors.
+  unsigned tmaCounterArgsStartIdx = newOperandIndex + newOperands.size();
   for (int i = 0; i < tmaBufferMapping.size(); ++i) {
     newOperands.push_back(zero);
   }
 
-  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
   // Patch the loop to add the new loop carried dependencies.
   scf::ForOp newForOp =
       replaceForOpWithNewSignature(builder, forOp, newOperands);
   forOp.erase();
   forOp = newForOp;
-  insertIdx = newForOp.getBody()->getArgument(newOperandIndex);
-  extractIdx = newForOp.getBody()->getArgument(newOperandIndex + 1);
-  if (phase) {
-    phase = newForOp.getBody()->getArgument(newOperandIndex + 2);
-  }
+
   auto tmaCounters = ArrayRef<BlockArgument>(newForOp.getBody()->getArguments())
-                         .slice(newOperandIndex + (phase ? 3 : 2));
+                         .slice(tmaCounterArgsStartIdx);
 
   // Update yield op with temporary yield values
   auto forYield = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
@@ -956,44 +961,70 @@
   }
   tmaBufferMapping.clear();
 
-  // FIXME: loads can be in different (stage, cluster)
-  // Create two counters for the insert and extract indices to avoid creating
-  // long liverange.
-  builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
-  insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
-  Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                                               insertIdx, numBuffersVal);
-  insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
-
-  extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
-  Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
-                                               extractIdx, numBuffersVal);
-  extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
-  if (phase) {
-    Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, one);
-    phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
+  builder.setInsertionPoint(forOp);
+  loc = forOp.getLoc();
+  int argIdx = newOperandIndex;
+  for (auto &[numBuffers, stageGroup] : stageGroups) {
+    Value insertIdx = newForOp.getBody()->getArgument(argIdx);
+    argIdx++;
+    Value extractIdx = newForOp.getBody()->getArgument(argIdx);
+    argIdx++;
+    Value phase = nullptr;
+    if (stageGroup.hasTMALoad) {
+      phase = newForOp.getBody()->getArgument(argIdx);
+      argIdx++;
+    }
+
+    // Create two counters for the insert and extract indices to avoid creating
+    // long liverange.
+    builder.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+
+    Value numBuffersVal =
+        builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
+    insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
+    Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 insertIdx, numBuffersVal);
+    insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
+    stageGroup.insertIdx = insertIdx;
+
+    extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
+    // Duplicate the constant to keep it from being carried across loops.
+    numBuffersVal = builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
+    Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 extractIdx, numBuffersVal);
+    extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
+    stageGroup.extractIdx = extractIdx;
+    if (phase) {
+      Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, one);
+      phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
+      stageGroup.phase = phase;
+    }
   }
-  createTMABarrierAndWait(forOp, asyncLoads, insertIdx, extractIdx, phase,
-                          numBuffers, barriers, loadToInfo);
+  createTMABarrierAndWait(forOp, asyncLoads, barriers, stageGroups, loadToInfo);
 
   auto [_, maxClusterId] = tt::getMinMaxCluster(forOp);
   for (AsyncLoad &asyncLoad : asyncLoads) {
+    auto [insertIdx, extractIdx, phase, _] = stageGroups[asyncLoad.numBuffers];
     if (auto loadOp = dyn_cast<tt::LoadOp>(asyncLoad.loadOp)) {
       createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
-                      loadToInfo, numStages, maxClusterId);
+                      loadToInfo, maxClusterId);
     } else {
       auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
       createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
                          extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
-                         loadToInfo, numStages);
+                         loadToInfo);
     }
   }
-  // Patch the yield with the updated counters.
-  forYield.setOperand(newOperandIndex + -1, insertIdx);
-  forYield.setOperand(newOperandIndex + 0, extractIdx);
-  if (phase) {
-    forYield.setOperand(newOperandIndex + 1, phase);
+  // Patch the yield with the updated counters. Subtract to account for the loop
+  // counter.
+  argIdx = newOperandIndex - 1;
+  for (auto &[numBuffers, stageGroup] : stageGroups) {
+    forYield.setOperand(argIdx++, stageGroup.insertIdx);
+    forYield.setOperand(argIdx++, stageGroup.extractIdx);
+    if (stageGroup.phase)
+      forYield.setOperand(argIdx++, stageGroup.phase);
   }
+  assert(argIdx + 1 == tmaCounterArgsStartIdx);
 
   tt::CoarseSchedule coarseSchedule(numStages);
   coarseSchedule.deSerialize(forOp);
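To make the new per-group bookkeeping easier to follow, here is a small standalone sketch (purely illustrative, plain C++ rather than pipeliner IR, and not code from this repository) of how the insert/extract indices and the phase bit evolve for two groups with different buffer counts. It mirrors the AddIOp/CmpIOp/SelectOp/XOrIOp sequence emitted above: each group wraps independently, which is why the loop now carries a separate index pair (plus an optional phase for TMA groups) per buffer count instead of one shared set.

#include <cstdio>

// Mirrors the per-iteration update the pipeliner emits for one StageGroup:
// the index advances modulo numBuffers, and the phase bit flips each time
// the extract index wraps back to 0 (only used when the group has a TMA load).
struct GroupCounters {
  int insertIdx = -1;
  int extractIdx = -1;
  int phase = 0;

  void step(int numBuffers) {
    insertIdx = (insertIdx + 1 < numBuffers) ? insertIdx + 1 : 0;
    bool wrapped = !(extractIdx + 1 < numBuffers);
    extractIdx = wrapped ? 0 : extractIdx + 1;
    phase = wrapped ? phase ^ 1 : phase;
  }
};

int main() {
  // Two loads with different latencies end up in different groups,
  // e.g. one double-buffered and one triple-buffered.
  GroupCounters twoBuf, threeBuf;
  for (int iter = 0; iter < 6; ++iter) {
    twoBuf.step(/*numBuffers=*/2);
    threeBuf.step(/*numBuffers=*/3);
    std::printf("iter %d: 2-buf (ins=%d ext=%d ph=%d)  3-buf (ins=%d ext=%d ph=%d)\n",
                iter, twoBuf.insertIdx, twoBuf.extractIdx, twoBuf.phase,
                threeBuf.insertIdx, threeBuf.extractIdx, threeBuf.phase);
  }
  return 0;
}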
