Commit b08a27e

[NVWS] Add pass to insert aref for TMA load (#7581)

The next PR will update `LowerAref` so that one half of `LoadMMASpecialization`, TMA load pipelining and lowering, can be replaced by the aref-based flow. The new pass is called `InsertAref`, but for now it only handles TMA and SMEM aref. After triton-lang/triton#7561 is merged, we can consider folding the value aref insertion logic into `InsertAref` and retiring `RewritePartitionDependencies`. The plan for TMEM is TBD. cc @3gx

1 parent a1f42ef commit b08a27e

File tree

24 files changed: +1065 −252 lines changed


include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h

Lines changed: 3 additions & 3 deletions

@@ -1,5 +1,5 @@
-#ifndef TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
-#define TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
+#ifndef TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H
+#define TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H

 #include "mlir/IR/ImplicitLocOpBuilder.h"

@@ -33,4 +33,4 @@ StageCluster getStageCluster(Operation *op);

 } // namespace mlir::triton::gpu

-#endif // TRITONGPU_WARPSPECIALIZATION_PARTITIONBUILDER_H
+#endif // TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 24 additions & 1 deletion

@@ -132,7 +132,14 @@ void combineRedundantWaitOps(
     llvm::SmallSetVector<gpu::AsyncWaitOp, 8> &waitOps);

 // Get the type of the view of a multi-buffered tensor value.
-gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy);
+gpu::MemDescType getBufferViewType(gpu::MemDescType allocTy,
+                                   bool mutableMemory = true);
+
+// Get a mutable, multi-buffered version of the given memdesc type, with
+// multiplicity "depth".
+gpu::MemDescType getMultiBufferedType(gpu::MemDescType memDescType,
+                                      int32_t depth);
+
 // Get a generic shared encoding for a tensor.
 gpu::SharedEncodingTrait getSharedEncoding(RankedTensorType ty);
 // Get a shared encoding for a tensor based on its uses.

@@ -157,6 +164,22 @@ Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,

 scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule);

+DenseSet<Operation *>
+getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp,
+                       std::function<bool(Operation *)> filter = nullptr);
+
+// Return the "first" op in terms of the stage and cluster ordering.
+Operation *
+getFirstUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                         CoarseSchedule &schedule,
+                         std::function<bool(Operation *)> filterUse = nullptr);
+
+// Return the "last" op in terms of the stage and cluster ordering.
+Operation *
+getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                        CoarseSchedule &schedule,
+                        std::function<bool(Operation *)> filterUse = nullptr);
+
 } // namespace triton
 } // namespace mlir
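Taken together, the two helpers round-trip between an allocation type and its per-buffer view. A minimal sketch (not code from this commit; `elemTy` and the types in comments are illustrative):

  // Given a memdesc for one pipeline iteration's buffer, e.g.
  //   !ttg.memdesc<128x64xf16, #shared, #smem>,
  // build the depth-3 multi-buffered allocation type by prepending the depth:
  ttg::MemDescType allocTy = triton::getMultiBufferedType(elemTy, /*depth=*/3);

  // The per-buffer view type drops the leading depth dimension again; the new
  // mutableMemory flag also allows requesting an immutable (read-only) view.
  ttg::MemDescType viewTy =
      triton::getBufferViewType(allocTy, /*mutableMemory=*/true);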

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 2 deletions

@@ -255,8 +255,11 @@ namespace mlir::triton {
 /// Replace all uses of `oldUse` with `val` and propagate the type if needed.
 /// This is useful when we need to change a memory descriptor from immutable to
 /// mutable.
-void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
-                                 Value val);
+/// The callback is invoked for each pair of an old and a cloned memdesc op
+/// as the type is propagated.
+void replaceUsesAndPropagateType(
+    OpBuilder &builder, Operation *oldUse, Value val,
+    std::function<void(Operation *, Operation *)> callback = nullptr);

 /// Replace all uses of `old` with a local load from `alloc` unless the use is a
 /// `ttg.local_alloc` with a matching shared encoding, in which case the shared

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 4 additions & 62 deletions

@@ -48,66 +48,6 @@ int getSelfLatencyFromAttr(Operation *op) {
   return val;
 }

-DenseSet<Operation *>
-getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp,
-                       std::function<bool(Operation *)> filter = nullptr) {
-  DenseSet<Operation *> topLevelUsers;
-  SmallVector<OpOperand *> q;
-  for (auto &use : op->getUses())
-    q.push_back(&use);
-  while (!q.empty()) {
-    auto use = q.pop_back_val();
-    auto yieldOp = dyn_cast<scf::YieldOp>(use->getOwner());
-    if (yieldOp && yieldOp->getParentOp() == forOp) {
-      for (auto &use :
-           forOp.getRegionIterArgs()[use->getOperandNumber()].getUses())
-        q.push_back(&use);
-      continue;
-    }
-    // Don't count view operations as uses. Follow them through to their
-    // users.
-    if (use->getOwner()->hasTrait<OpTrait::MemDescViewTrait>()) {
-      for (auto &use : use->getOwner()->getUses())
-        q.push_back(&use);
-      continue;
-    }
-    if (filter && !filter(use->getOwner()))
-      continue;
-    Operation *topLevelUser =
-        forOp.getBody()->findAncestorOpInBlock(*use->getOwner());
-    topLevelUsers.insert(topLevelUser);
-  }
-  return topLevelUsers;
-}
-
-Operation *getFirstUseOfPipelinedOp(SmallVector<Operation *> ops,
-                                    scf::ForOp forOp,
-                                    CoarseSchedule &schedule) {
-  Operation *firstUser = nullptr;
-  DenseSet<Operation *> topLevelUsers;
-  for (Operation *op : ops) {
-    auto users = getTopLevelUsersInLoop(op, forOp);
-    topLevelUsers.insert(users.begin(), users.end());
-  }
-  for (Operation *topLevelUser : topLevelUsers) {
-    assert(schedule.count(topLevelUser) && "op user not found in the schedule");
-    auto [_useStage, _useCluster] = schedule[topLevelUser];
-    if (!firstUser) {
-      firstUser = topLevelUser;
-    } else {
-      auto [_firstUserStage, _firstUserCluster] = schedule[firstUser];
-      if (_useStage < _firstUserStage ||
-          (_useStage == _firstUserStage &&
-           schedule.clusters.isBefore(_useCluster, _firstUserCluster)) ||
-          (_useStage == _firstUserStage && _useCluster == _firstUserCluster &&
-           topLevelUser->isBeforeInBlock(firstUser))) {
-        firstUser = topLevelUser;
-      }
-    }
-  }
-  return firstUser;
-}
-
 // Check if the load can be pipelined entirely in shared memory,
 // or if we need to load to registers.
 bool mustLoadToRegisters(Operation *op) {

@@ -142,7 +82,8 @@ int getDefUseStageDiff(Operation *op, scf::ForOp forOp,
   assert(schedule.count(op) && "Op not found in the schedule");
   int defStage = schedule[op].first;
   std::optional<int> useStage;
-  DenseSet<Operation *> topLevelUsers = getTopLevelUsersInLoop(op, forOp);
+  DenseSet<Operation *> topLevelUsers =
+      triton::getTopLevelUsersInLoop(op, forOp);
   // Special case for loads used by local_alloc:
   // we must consider the uses of the local_alloc, as it may be removed and its
   // uses will become direct uses of the async load.

@@ -152,7 +93,8 @@
   DenseSet<Operation *> allocUsers;
   for (Operation *topLevelUser : topLevelUsers) {
     if (auto localAlloc = dyn_cast<ttg::LocalAllocOp>(topLevelUser)) {
-      DenseSet<Operation *> users = getTopLevelUsersInLoop(localAlloc, forOp);
+      DenseSet<Operation *> users =
+          triton::getTopLevelUsersInLoop(localAlloc, forOp);
       allocUsers.insert(users.begin(), users.end());
     }
   }

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 105 additions & 4 deletions

@@ -561,15 +561,25 @@ void mlir::triton::combineRedundantWaitOps(
   }
 }

-ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy) {
-  Attribute sharedMemorySpace =
-      ttg::SharedMemorySpaceAttr::get(allocTy.getContext());
+ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy,
+                                                 bool mutableMemory) {
   return ttg::MemDescType::get(allocTy.getShape().drop_front(),
                                allocTy.getElementType(), allocTy.getEncoding(),
-                               sharedMemorySpace, /*mutableMemory=*/true,
+                               allocTy.getMemorySpace(), mutableMemory,
                                /*allocShape=*/allocTy.getAllocShape());
 }

+ttg::MemDescType
+mlir::triton::getMultiBufferedType(ttg::MemDescType memDescType,
+                                   int32_t depth) {
+  auto shape = memDescType.getShape();
+  SmallVector<int64_t> bufferShape(shape.begin(), shape.end());
+  bufferShape.insert(bufferShape.begin(), depth);
+  return ttg::MemDescType::get(
+      bufferShape, memDescType.getElementType(), memDescType.getEncoding(),
+      memDescType.getMemorySpace(), /*mutableMemory*/ true);
+}
+
 ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) {
   auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
   auto order = ttg::getOrder(ty);

@@ -810,3 +820,94 @@ scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp,
   }
   return forOp;
 }
+
+DenseSet<Operation *>
+triton::getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp,
+                               std::function<bool(Operation *)> filter) {
+  DenseSet<Operation *> topLevelUsers;
+  SmallVector<OpOperand *> q;
+  for (auto &use : op->getUses())
+    q.push_back(&use);
+  while (!q.empty()) {
+    auto use = q.pop_back_val();
+    auto yieldOp = dyn_cast<scf::YieldOp>(use->getOwner());
+    if (yieldOp && yieldOp->getParentOp() == forOp) {
+      for (auto &use :
+           forOp.getRegionIterArgs()[use->getOperandNumber()].getUses())
+        q.push_back(&use);
+      continue;
+    }
+    // Don't count view operations as uses. Follow them through to their
+    // users.
+    if (use->getOwner()->hasTrait<OpTrait::MemDescViewTrait>()) {
+      for (auto &use : use->getOwner()->getUses())
+        q.push_back(&use);
+      continue;
+    }
+    if (filter && !filter(use->getOwner()))
+      continue;
+    Operation *topLevelUser =
+        forOp.getBody()->findAncestorOpInBlock(*use->getOwner());
+    topLevelUsers.insert(topLevelUser);
+  }
+  return topLevelUsers;
+}
+
+// Helper function that finds an operation based on a comparison predicate
+static Operation *getUseOfPipelinedOp(
+    ArrayRef<Operation *> ops, scf::ForOp forOp,
+    triton::CoarseSchedule &schedule,
+    std::function<bool(Operation *)> filterUse,
+    std::function<bool(Operation *, Operation *)> shouldPrefer) {
+  DenseSet<Operation *> topLevelUsers;
+  Operation *selectedUser = nullptr;
+  for (Operation *op : ops) {
+    auto users = triton::getTopLevelUsersInLoop(op, forOp, filterUse);
+    topLevelUsers.insert(users.begin(), users.end());
+  }
+  for (Operation *topLevelUser : topLevelUsers) {
+    assert(schedule.count(topLevelUser) && "op user not found in the schedule");
+    if (!selectedUser || shouldPrefer(topLevelUser, selectedUser)) {
+      selectedUser = topLevelUser;
+    }
+  }
+  return selectedUser;
+}
+
+Operation *
+triton::getFirstUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                                 triton::CoarseSchedule &schedule,
+                                 std::function<bool(Operation *)> filterUse) {
+  return getUseOfPipelinedOp(
+      ops, forOp, schedule, filterUse,
+      [&](Operation *candidate, Operation *current) {
+        auto [candidateStage, candidateCluster] = schedule[candidate];
+        auto [currentStage, currentCluster] = schedule[current];
+
+        return candidateStage < currentStage ||
+               (candidateStage == currentStage &&
+                schedule.clusters.isBefore(candidateCluster, currentCluster)) ||
+               (candidateStage == currentStage &&
+                candidateCluster == currentCluster &&
+                candidate->isBeforeInBlock(current));
+      });
+}
+
+Operation *
+triton::getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
+                                triton::CoarseSchedule &schedule,
+                                std::function<bool(Operation *)> filterUse) {
+  return getUseOfPipelinedOp(
+      ops, forOp, schedule, filterUse,
+      [&](Operation *candidate, Operation *current) {
+        auto [candidateStage, candidateCluster] = schedule[candidate];
+        auto [currentStage, currentCluster] = schedule[current];
+
+        return candidateStage > currentStage ||
+               (candidateStage == currentStage &&
+                schedule.clusters.isBefore(currentCluster, candidateCluster)) ||
+               (candidateStage == currentStage &&
+                candidateCluster == currentCluster &&
+                current->isBeforeInBlock(candidate));
+      });
+}
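This factoring replaces the open-coded "first user" scan that previously lived in LowerLoops.cpp: getFirstUseOfPipelinedOp and getLastUseOfPipelinedOp now differ only in the shouldPrefer predicate (stage first, then cluster order, then block order). An illustrative call (the load op and filter are hypothetical, not from this commit):

  // Find where a pipelined load's result is first consumed, ignoring
  // local_alloc users (hypothetical filter), e.g. to place consumer-side
  // synchronization no later than the first real use.
  Operation *firstUse = triton::getFirstUseOfPipelinedOp(
      {loadOp}, forOp, schedule, /*filterUse=*/[](Operation *user) {
        return !isa<ttg::LocalAllocOp>(user);
      });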

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 8 additions & 4 deletions

@@ -1463,8 +1463,10 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) {
 } // namespace mlir

 namespace mlir::triton {
-void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
-                                 Value val) {
+
+void replaceUsesAndPropagateType(
+    OpBuilder &builder, Operation *oldUse, Value val,
+    std::function<void(Operation *, Operation *)> callback) {
   OpBuilder::InsertionGuard guard(builder);
   SmallVector<Operation *> opsToDelete;
   SmallVector<OpOperand *> operandsToReplace;

@@ -1515,7 +1517,10 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
     assert(newVal && "unhandled memdesc view");
     newVal.getDefiningOp()->setAttrs(user->getAttrs());
     replaceUsesAndPropagateType(builder, user, newVal);
-    opsToDelete.push_back(use.getOwner());
+    opsToDelete.push_back(user);
+    if (callback) {
+      callback(user, newVal.getDefiningOp());
+    }
   }

   // Perform late replacement.

@@ -1530,7 +1535,6 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
     wait.replaceAllUsesWith(newWait.getResults());
     wait.erase();
   } else {
-    Operation *op = operand->getOwner();
     operand->set(val);
   }
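A minimal sketch of the new callback hook (all names hypothetical): the caller can record the mapping from each original memdesc view op to its retyped clone while the replacement recurses, for example to patch up side tables that still point at the old ops.

  // Replace uses of `oldUse` with the mutable buffer `newVal`, recording each
  // (old view op -> cloned view op) pair produced during type propagation.
  DenseMap<Operation *, Operation *> cloneMap;
  triton::replaceUsesAndPropagateType(
      builder, oldUse, newVal, [&](Operation *oldOp, Operation *clonedOp) {
        cloneMap[oldOp] = clonedOp;
      });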
}

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 0 additions & 9 deletions

@@ -85,15 +85,6 @@ getPartitionScheme(scf::ForOp loop, const WarpSchedule &schedule) {
 // Utilities
 //===----------------------------------------------------------------------===//

-static void replaceAllUsesDominatedBy(Operation *domOp, Value newValue,
-                                      Value oldValue, DominanceInfo &domInfo) {
-  if (newValue == oldValue)
-    return;
-  oldValue.replaceUsesWithIf(newValue, [&](OpOperand &use) {
-    return domInfo.properlyDominates(domOp, use.getOwner());
-  });
-}
-
 static std::pair<Value, Value> postIncrementModulo(ImplicitLocOpBuilder &b,
                                                    Value index, Value phase,
                                                    unsigned numStages) {

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 13 additions & 7 deletions

@@ -63,12 +63,14 @@ struct AsyncRef {
                  StageCluster srcStageCluster) {
     auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
     auto enterOp = b.createInto<triton::nvws::ArefPutEnterOp>(
-        partition, srcStageCluster, viewType, aref, zero, zero);
+        partition, srcStageCluster, viewType, tokenType, aref, zero, zero);
+    auto token = enterOp.getToken();

-    auto exitOp = [this, &partition, srcStageCluster](PartitionBuilder &b) {
+    auto exitOp = [this, &partition, srcStageCluster,
+                   token](PartitionBuilder &b) {
       auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
       auto exitOp = b.createInto<triton::nvws::ArefPutExitOp>(
-          partition, srcStageCluster, aref, zero,
+          partition, srcStageCluster, aref, token, zero,
           b.getArrayAttr(SmallVector<Attribute>{triton::nvws::AsyncOpAttr::get(
               aref.getContext(), triton::nvws::AsyncOp::NONE)}));
     };

@@ -79,12 +81,14 @@ struct AsyncRef {
                  StageCluster srcStageCluster) {
     auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
     auto enterOp = b.createInto<triton::nvws::ArefGetEnterOp>(
-        partition, srcStageCluster, viewType, aref, zero, zero);
+        partition, srcStageCluster, viewType, tokenType, aref, zero, zero);
+    auto token = enterOp.getToken();

-    auto exitOp = [this, &partition, srcStageCluster](PartitionBuilder &b) {
+    auto exitOp = [this, &partition, srcStageCluster,
+                   token](PartitionBuilder &b) {
       auto zero = b.create<arith::ConstantOp>(b.getI32IntegerAttr(0));
       auto exitOp = b.createInto<triton::nvws::ArefGetExitOp>(
-          partition, srcStageCluster, aref, zero,
+          partition, srcStageCluster, aref, token, zero,
           b.getArrayAttr(SmallVector<Attribute>{triton::nvws::AsyncOpAttr::get(
               aref.getContext(), triton::nvws::AsyncOp::NONE)}));
     };

@@ -93,6 +97,7 @@ struct AsyncRef {

   Value aref;
   MemDescType viewType;
+  AsyncTokenType tokenType;
 };

 //===----------------------------------------------------------------------===//

@@ -137,7 +142,8 @@ AsyncRef DependencyRewriter::allocateAsyncValue(RankedTensorType tensorType,

   endBuilder.create<nvws::ArefDestroyOp>(aref);

-  return AsyncRef{aref, getBufferViewType(allocType)};
+  return AsyncRef{aref, getBufferViewType(allocType),
+                  b.getType<AsyncTokenType>()};
 }

 LogicalResult DependencyRewriter::run() {
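The token threading makes each enter/exit pair explicit in the IR instead of implied by op order. A sketch of the producer side after this change (the return shape of putView is assumed here; the hunk only shows its body):

  // Enter yields both the buffer view and a token; the exit closure captures
  // that token and feeds it to ArefPutExitOp, so each exit releases exactly
  // the slot its enter acquired.
  auto [view, exitFn] = asyncRef.putView(b, partition, srcStageCluster);
  // ... emit the copy or compute that fills `view` ...
  exitFn(b); // emits triton::nvws::ArefPutExitOp carrying the captured token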
