
Commit 61a327a

Merge commit 'eb73b0373a7fb4cd2e563f68e3488a96525562eb'
2 parents: d8f73fe + eb73b03

File tree: 19 files changed, +319 −83 lines

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 3 additions & 2 deletions
@@ -38,18 +38,19 @@ class MMAv5PipelineableOperandsHelper {
       : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) {
     run();
   }
+
   bool isPipelineable = false;
   // If true, the existing operand loads are all been found and their
   // pipelineability has been determined.
   bool isOperandsStateDetermined = false;
-  SmallVector<Operation *> unpipelineableOperandLoads;
+  SmallVector<Operation *> unpipelineableOperandDefs;

 private:
   MMAv5OpInterface mmaOp;
   scf::ForOp forOp;
   std::function<bool(Operation *)> isLoadToBePipelined;
-  bool comesFromLoadOrOutsideLoop(Value v, Operation *&foundLoad);
   void run();
+  bool isOperandPipelineable(Value v, Operation *&foundDef);
 };

 //===----------------------------------------------------------------------===//
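With the rename from unpipelineableOperandLoads / comesFromLoadOrOutsideLoop to unpipelineableOperandDefs / isOperandPipelineable, the helper now reports whichever operation defines an unpipelineable operand (a load-backed local_alloc, but also local_store, tmem_store, or tmem_alloc producers), not only loads. A minimal caller sketch under those assumptions; the predicate and the diagnostics below are illustrative and not part of this commit:

```cpp
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"

using namespace mlir;
namespace ttng = triton::nvidia_gpu;

// Illustrative use of the helper's public state after this change.
static bool reportMMAPipelineability(ttng::MMAv5OpInterface mma,
                                     scf::ForOp forOp) {
  // Example predicate only: treat every candidate load as one we intend to
  // pipeline; real callers pass their scheduling decision here.
  auto isLoadToBePipelined = [](Operation *) { return true; };
  ttng::MMAv5PipelineableOperandsHelper helper(mma, forOp, isLoadToBePipelined);
  if (!helper.isOperandsStateDetermined)
    return false; // some operand producers could not be traced
  for (Operation *def : helper.unpipelineableOperandDefs)
    def->emitRemark() << "operand definition blocks MMA operand pipelining";
  return helper.isPipelineable;
}
```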

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 52 additions & 3 deletions
@@ -636,6 +636,9 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {

   // Make sure all ops have attributes.
   for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (!schedule.count(&op)) {
+      op.emitError() << "op not found in the schedule";
+    }
     assert(schedule.count(&op) && "op not found in the schedule");
   }
   return forOp;
@@ -796,6 +799,41 @@ getTmemUseStageBoundOps(ttng::TMEMAllocOp alloc, scf::ForOp forOp,
   return bounds;
 }

+Operation *hoistBufferOutOfLoop(scf::ForOp forOp, Operation *op,
+                                CoarseSchedule &schedule) {
+  Operation *newStore = nullptr;
+  if (!isa<ttng::TMEMAllocOp, ttg::LocalAllocOp>(op))
+    return nullptr;
+  // If the alloc is already out of the loop, there is nothing to do.
+  if (!forOp->isAncestor(op))
+    return nullptr;
+  OpBuilderForStage builder(op->getLoc(), forOp, schedule);
+  auto allocType = dyn_cast<MemDescType>(op->getResult(0).getType());
+  auto newType = triton::gpu::MemDescType::get(
+      allocType.getShape(), allocType.getElementType(), allocType.getEncoding(),
+      allocType.getMemorySpace(),
+      /*mutableMemory=*/true);
+  auto newAlloc = builder.clone(*op);
+  newAlloc->getResult(0).setType(newType);
+  builder.setStageCluster(schedule[op]);
+  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(newAlloc)) {
+    tmemAlloc.getSrcMutable().clear();
+    builder.setInsertionPointAfter(op);
+    Value trueVal = builder.create<arith::ConstantIntOp>(1, 1);
+    newStore = builder.create<ttng::TMEMStoreOp>(tmemAlloc.getResult(),
+                                                 op->getOperand(0), trueVal);
+  } else {
+    auto localAlloc = cast<ttg::LocalAllocOp>(newAlloc);
+    localAlloc.getSrcMutable().clear();
+    builder.setInsertionPointAfter(op);
+    newStore = builder.create<ttg::LocalStoreOp>(op->getOperand(0),
+                                                 localAlloc.getResult());
+  }
+  op->replaceAllUsesWith(newAlloc);
+  op->erase();
+  return newStore;
+}
+
 void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
                              ttng::MMAv5OpInterface mma, int mmaSelfLatency,
                              ttng::TMEMAllocOp alloc, int phaseArgIdx,
@@ -818,13 +856,24 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,

   ttng::MMAv5PipelineableOperandsHelper mmaPipeHelper(mma, forOp,
                                                       isLoadToBePipelined);
+
+  SmallVector<Operation *> updatedDefs;
+  for (auto def : mmaPipeHelper.unpipelineableOperandDefs) {
+    auto newStore = hoistBufferOutOfLoop(forOp, def, schedule);
+    if (newStore) {
+      updatedDefs.push_back(newStore);
+    } else {
+      updatedDefs.push_back(def);
+    }
+  }
+
   if (!mmaPipeHelper.isPipelineable &&
       mmaPipeHelper.isOperandsStateDetermined) {
     // If the operands are not pipelineable, we need to insert a sync point
     // before the earliest operand load
-    for (auto load : mmaPipeHelper.unpipelineableOperandLoads) {
-      if (!latestSyncPoint || schedule.isOpBefore(load, *latestSyncPoint)) {
-        latestSyncPoint = load;
+    for (auto def : updatedDefs) {
+      if (!latestSyncPoint || schedule.isOpBefore(def, *latestSyncPoint)) {
+        latestSyncPoint = def;
       }
     }
   }
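In effect, hoistBufferOutOfLoop rewrites an in-loop tmem_alloc/local_alloc into a mutable allocation hoisted out of the loop plus an explicit tmem_store/local_store of the original source value left at the alloc's position; the returned store is what the caller then treats as the sync-point candidate. A small illustrative wrapper (not part of the commit) that mirrors the loop in the last hunk:

```cpp
// Illustrative wrapper: hoist the buffer if possible and return the op that
// should act as the MMA sync-point candidate afterwards.
static Operation *hoistAndPickSyncCandidate(scf::ForOp forOp, Operation *def,
                                            CoarseSchedule &schedule) {
  if (Operation *newStore = hoistBufferOutOfLoop(forOp, def, schedule))
    return newStore; // the in-loop store now stands in for the hoisted alloc
  return def; // def was not an in-loop alloc, keep it as the candidate
}
```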

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 25 additions & 22 deletions
@@ -14,8 +14,8 @@ namespace ttng = mlir::triton::nvidia_gpu;
 // MMA Pipeline Analysis
 //===----------------------------------------------------------------------===//

-bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
-    Value v, Operation *&foundLoad) {
+bool ttng::MMAv5PipelineableOperandsHelper::isOperandPipelineable(
+    Value v, Operation *&foundDef) {
   if (forOp.isDefinedOutsideOfLoop(v)) {
     return true;
   }
@@ -25,14 +25,16 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
   while (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(v.getDefiningOp())) {
     v = v.getDefiningOp()->getOperand(0);
   }
-  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(v.getDefiningOp())) {
-    foundLoad = tmemAlloc;
+  if (isa<ttg::LocalStoreOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp>(
+          v.getDefiningOp())) {
+    foundDef = v.getDefiningOp();
     return false;
   }
   auto localAlloc = dyn_cast<ttg::LocalAllocOp>(v.getDefiningOp());
   if (!localAlloc) {
     return false;
   }
+  foundDef = localAlloc;
   if (!localAlloc.getSrc()) {
     return false;
   }
@@ -44,17 +46,18 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
           localAllocSrc)) {
     return false;
   }
-  foundLoad = localAllocSrc;
-  if (!isLoadToBePipelined(foundLoad)) {
+  foundDef = localAllocSrc;
+  if (!isLoadToBePipelined(localAllocSrc)) {
     return false;
   }
-  if (canBeAsyncLoad(foundLoad)) {
+  if (canBeAsyncLoad(localAllocSrc)) {
     return true;
   }
   return false;
 }

 void ttng::MMAv5PipelineableOperandsHelper::run() {
+  unpipelineableOperandDefs.clear();
   isOperandsStateDetermined = true;
   // Accumulator alloc must be outside the loop.
   auto tmemAlloc = mmaOp.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
@@ -65,17 +68,17 @@ void ttng::MMAv5PipelineableOperandsHelper::run() {
     return;
   }
   if (auto dotOp = dyn_cast<tt::DotOpInterface>(mmaOp.getOperation())) {
-    Operation *foundLoad = nullptr;
-    if (!comesFromLoadOrOutsideLoop(dotOp.getA(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    Operation *foundDef = nullptr;
+    if (!isOperandPipelineable(dotOp.getA(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
-    if (!comesFromLoadOrOutsideLoop(dotOp.getB(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    if (!isOperandPipelineable(dotOp.getB(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
@@ -95,24 +98,24 @@ void ttng::MMAv5PipelineableOperandsHelper::run() {
       isOperandsStateDetermined = false;
       return;
     }
-    Operation *foundLoad = nullptr;
-    if (!comesFromLoadOrOutsideLoop(scaledOp.getAScale(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    Operation *foundDef = nullptr;
+    if (!isOperandPipelineable(scaledOp.getAScale(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
-    if (!comesFromLoadOrOutsideLoop(scaledOp.getBScale(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    if (!isOperandPipelineable(scaledOp.getBScale(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
   }
   isPipelineable =
-      isOperandsStateDetermined && unpipelineableOperandLoads.empty();
+      isOperandsStateDetermined && unpipelineableOperandDefs.empty();
 }

 bool ttng::hasAccReadModifyWrite(ttng::MMAv5OpInterface mma, scf::ForOp forOp) {

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 13 additions & 0 deletions
@@ -1399,6 +1399,19 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     return;
   LogicalResult result = getRematerializableSlice(
       op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
+
+  // If a value is already assigned to a _different_ layout,
+  // we cannot propagate past this op (as it would conflict with
+  // an already-assigned layout).
+  for (auto [val, enc] : tempLayout) {
+    auto preexistingLayout = layout.find(val);
+    if (preexistingLayout != layout.end() &&
+        preexistingLayout->second != enc) {
+      result = failure();
+      break;
+    }
+  }
+
   // If we can rematerialize the rest of the ext slice we can ignore this
   // ext as it won't need a convert.
   if (result.succeeded()) {
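The added loop guards rematerialization against requiring two different encodings for the same value. A standalone restatement of that rule (illustrative helper, not in the commit), assuming layout and tempLayout are DenseMap<Value, Attribute> maps from a value to its chosen encoding, as used elsewhere in this pass:

```cpp
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"

// Returns true when the new slice would re-assign a value to a different
// encoding than the one it already has; this corresponds to forcing
// `result = failure()` in the hunk above.
static bool conflictsWithAssignedLayout(
    const llvm::DenseMap<mlir::Value, mlir::Attribute> &layout,
    const llvm::DenseMap<mlir::Value, mlir::Attribute> &tempLayout) {
  for (auto [val, enc] : tempLayout) {
    auto it = layout.find(val);
    if (it != layout.end() && it->second != enc)
      return true; // same value would need two conflicting layouts
  }
  return false;
}
```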

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

Lines changed: 42 additions & 23 deletions
@@ -39,9 +39,9 @@ struct FenceInsertionPass
     mod.walk([&](DotOpInterface dotOp) {
       Value a = dotOp.getA();
       Value b = dotOp.getB();
-      bool aDependsOnShared = dependOnCopyRegToShared(a);
-      bool bDependsOnShared = dependOnCopyRegToShared(b);
-      if (!aDependsOnShared && !bDependsOnShared)
+      SmallVector<Operation *> copyRegToSharedOpsA = findCopyRegToSharedOps(a);
+      SmallVector<Operation *> copyRegToSharedOpsB = findCopyRegToSharedOps(b);
+      if (copyRegToSharedOpsA.empty() && copyRegToSharedOpsB.empty())
         return WalkResult::advance();

       OpBuilder builder(dotOp);
@@ -50,11 +50,13 @@ struct FenceInsertionPass
       // If there is all the dependencies are outside of the loop try to hoist
       // the fence.
       while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
-        if (aDependsOnShared &&
-            loopOp->isAncestor(a.getParentBlock()->getParentOp()))
+        if (!copyRegToSharedOpsA.empty() &&
+            llvm::any_of(copyRegToSharedOpsA,
+                         [&](Operation *op) { return loopOp->isAncestor(op); }))
           break;
-        if (bDependsOnShared &&
-            loopOp->isAncestor(b.getParentBlock()->getParentOp()))
+        if (!copyRegToSharedOpsB.empty() &&
+            llvm::any_of(copyRegToSharedOpsB,
+                         [&](Operation *op) { return loopOp->isAncestor(op); }))
           break;
         loopOp.moveOutOfLoop(fence);
       }
@@ -72,31 +74,47 @@ struct FenceInsertionPass

 private:
   // Return true if the operand depends on a copy from register to shared.
-  bool dependOnCopyRegToShared(Value operand) {
+  SmallVector<Operation *> findCopyRegToSharedOps(Value operand) {
     DenseSet<Value> visited;
-    return dependOnCopyRegToShared(operand, visited);
+    llvm::SetVector<Operation *> result;
+    findCopyRegToSharedOps(operand, visited, result);
+    return result.takeVector();
   }

-  bool dependOnCopyRegToShared(Value operand, DenseSet<Value> &visited) {
+  void findCopyRegToSharedOps(Value operand, DenseSet<Value> &visited,
+                              llvm::SetVector<Operation *> &result) {
     // If the value has already been visited we can safely return false as we
     // would early return when true.
     if (visited.count(operand))
-      return false;
+      return;
     visited.insert(operand);
     if (!isa<triton::gpu::MemDescType>(operand.getType()))
-      return false;
+      return;

     auto op = operand.getDefiningOp();
     if (op) {
       // reach an alloc copying from register, we need a fence.
-      if (isa<ttg::LocalAllocOp>(op) && cast<ttg::LocalAllocOp>(op).getSrc())
-        return true;
+      if (auto localAlloc = dyn_cast<ttg::LocalAllocOp>(op)) {
+        if (localAlloc.getSrc()) {
+          result.insert(op);
+        }
+        // Check if there are local_store ops that write to that buffer.
+        for (auto user : localAlloc.getResult().getUsers()) {
+          while (user->hasOneUse() &&
+                 user->hasTrait<OpTrait::MemDescViewTrait>()) {
+            user = *user->getUsers().begin();
+          }
+          if (isa<ttg::LocalStoreOp>(user)) {
+            result.insert(user);
+            return;
+          }
+        }
+      }
       // if it is not an alloc, iterate over the operands.
       for (auto v : op->getOperands()) {
-        if (dependOnCopyRegToShared(v))
-          return true;
+        findCopyRegToSharedOps(v, visited, result);
       }
-      return false;
+      return;
     }

     // reach BlockArgument
@@ -108,22 +126,23 @@ struct FenceInsertionPass
       assert(argNum != 0 && "induction var cannot be memdesc type");
       --argNum;
       // prologue
-      if (dependOnCopyRegToShared(forOp.getInitArgs()[argNum], visited))
-        return true;
+      findCopyRegToSharedOps(forOp.getInitArgs()[argNum], visited, result);
       // yield
       auto yieldOp = forOp.getBody()->getTerminator();
       Value v = yieldOp->getOperand(argNum);
-      return dependOnCopyRegToShared(v, visited);
+      findCopyRegToSharedOps(v, visited, result);
+      return;
     }

     // look through `ttg.warp_specialize`.
     if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(argOwner)) {
-      return dependOnCopyRegToShared(
-          wsOp.getParentOp().getExplicitCaptures()[argNum]);
+      findCopyRegToSharedOps(wsOp.getParentOp().getExplicitCaptures()[argNum],
+                             visited, result);
+      return;
     }

     // Conservatively return true for other ops
-    return true;
+    result.insert(argOwner);
   }
 };

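The refactor from a boolean query to findCopyRegToSharedOps relies on llvm::SetVector so that each fence-relevant op is recorded once, in discovery order, and can then be handed back as a plain vector. A tiny self-contained sketch of that idiom (illustrative only, using int instead of Operation *):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

// SetVector deduplicates while keeping first-insertion order; takeVector()
// releases the elements as an ordinary vector, as the new helper does above.
llvm::SmallVector<int> collectUnique(llvm::ArrayRef<int> in) {
  llvm::SetVector<int> seen;
  for (int v : in)
    seen.insert(v); // duplicates are silently ignored
  return seen.takeVector();
}
```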

python/src/gluon_ir.cc

Lines changed: 5 additions & 0 deletions
@@ -299,6 +299,11 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
                                                  src);
           })
+      .def("create_fence_async_shared",
+           [](GluonOpBuilder &self, bool bCluster) -> OpState {
+             return self.create<ttng::FenceAsyncSharedOp>(bCluster);
+           })
+
       .def("create_broadcast",
            [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value {
              return self.create<tt::BroadcastOp>(retTy, arg);
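The new Gluon binding forwards to the ttng::FenceAsyncSharedOp builder, with bCluster requesting a cluster-wide fence. A minimal C++-side sketch of the equivalent op creation (illustrative only; the insertion point and location handling shown here are plain OpBuilder usage, not GluonOpBuilder internals, and the include path is assumed):

```cpp
#include "mlir/IR/Builders.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace ttng = mlir::triton::nvidia_gpu;

// Illustrative only: insert a fence_async_shared op right before `op`.
static void insertFenceBefore(mlir::Operation *op, bool bCluster = false) {
  mlir::OpBuilder builder(op);
  builder.create<ttng::FenceAsyncSharedOp>(op->getLoc(), bCluster);
}
```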

python/src/ir.cc

Lines changed: 1 addition & 6 deletions
@@ -1425,12 +1425,7 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_expand_dims",
            [](TritonOpBuilder &self, Value &arg, int axis) -> Value {
-             auto argType = dyn_cast<RankedTensorType>(arg.getType());
-             auto argEltType = argType.getElementType();
-             std::vector<int64_t> retShape = argType.getShape();
-             retShape.insert(retShape.begin() + axis, 1);
-             return self.create<ExpandDimsOp>(
-                 RankedTensorType::get(retShape, argEltType), arg, axis);
+             return self.create<ExpandDimsOp>(arg, axis);
           })
       .def("create_cat",
            [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
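The simplified binding leaves the result type to ExpandDimsOp itself (its builder/type inference) rather than computing it by hand, judging from the one-argument create call above. For reference, the shape rule the removed lines implemented is simply "insert a unit dimension at axis"; a standalone sketch:

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: the expand_dims result shape inserts a 1 at `axis`,
// e.g. expandDimsShape({4, 8}, 1) -> {4, 1, 8}.
std::vector<int64_t> expandDimsShape(const std::vector<int64_t> &shape,
                                     int axis) {
  std::vector<int64_t> ret = shape;
  ret.insert(ret.begin() + axis, 1);
  return ret;
}
```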
