
Commit e461a5b

Pawel/hoist unpipelineable operands (#7082)
Hoist tmem and smem allocations for MMAv5 operands whose wait may be pushed to the next stage even though the operands themselves are not pipelined. This avoids a bug where tmem/smem allocations get clobbered by the allocator because their live ranges are implicitly longer than the allocator can see. To support proper wait placement in the presence of hoisted allocs (which leaves stores instead of allocs in the loop), `MMAv5PipelineableOperandsHelper` was generalized a bit to look for any ops that overwrite operand memory, not only loads. The fence insertion pass also had to be generalized a bit to detect shared memory being overwritten by stores.
1 parent e3ac59c commit e461a5b

File tree: 7 files changed (+183 -52 lines)

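For context, the hoisting described in the commit message roughly turns an in-loop shared-memory allocation feeding an MMA operand into a loop-invariant allocation plus an in-loop store, so the buffer's true live range is visible to the allocator. The sketch below is illustrative pseudo-IR only (op syntax follows the tests in this commit, but types, encodings, and the surrounding pipeliner state are simplified; it is not actual pass output):

Before (alloc inside the loop; its live range extends past what the allocator sees once the wait is pushed to the next stage):

  scf.for %i = %lb to %ub step %c1 : i32 {
    %a = ttg.local_alloc %tile : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
    // ... MMA consuming %a ...
  }

After (alloc hoisted out of the loop and made mutable; a store is left in the loop in its place):

  %a = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
  scf.for %i = %lb to %ub step %c1 : i32 {
    ttg.local_store %tile, %a : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
    // ... MMA consuming %a ...
  }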

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 3 additions & 2 deletions
@@ -38,18 +38,19 @@ class MMAv5PipelineableOperandsHelper
       : mmaOp(mmaOp), forOp(forOp), isLoadToBePipelined(isLoadToBePipelined) {
     run();
   }
+
   bool isPipelineable = false;
   // If true, the existing operand loads are all been found and their
   // pipelineability has been determined.
   bool isOperandsStateDetermined = false;
-  SmallVector<Operation *> unpipelineableOperandLoads;
+  SmallVector<Operation *> unpipelineableOperandDefs;
 
 private:
   MMAv5OpInterface mmaOp;
   scf::ForOp forOp;
   std::function<bool(Operation *)> isLoadToBePipelined;
-  bool comesFromLoadOrOutsideLoop(Value v, Operation *&foundLoad);
   void run();
+  bool isOperandPipelineable(Value v, Operation *&foundDef);
 };
 
 //===----------------------------------------------------------------------===//

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 52 additions & 3 deletions
@@ -636,6 +636,9 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
 
   // Make sure all ops have attributes.
   for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (!schedule.count(&op)) {
+      op.emitError() << "op not found in the schedule";
+    }
     assert(schedule.count(&op) && "op not found in the schedule");
   }
   return forOp;
@@ -796,6 +799,41 @@ getTmemUseStageBoundOps(ttng::TMEMAllocOp alloc, scf::ForOp forOp,
   return bounds;
 }
 
+Operation *hoistBufferOutOfLoop(scf::ForOp forOp, Operation *op,
+                                CoarseSchedule &schedule) {
+  Operation *newStore = nullptr;
+  if (!isa<ttng::TMEMAllocOp, ttg::LocalAllocOp>(op))
+    return nullptr;
+  // If the alloc is already out of the loop, there is nothing to do.
+  if (!forOp->isAncestor(op))
+    return nullptr;
+  OpBuilderForStage builder(op->getLoc(), forOp, schedule);
+  auto allocType = dyn_cast<MemDescType>(op->getResult(0).getType());
+  auto newType = triton::gpu::MemDescType::get(
+      allocType.getShape(), allocType.getElementType(), allocType.getEncoding(),
+      allocType.getMemorySpace(),
+      /*mutableMemory=*/true);
+  auto newAlloc = builder.clone(*op);
+  newAlloc->getResult(0).setType(newType);
+  builder.setStageCluster(schedule[op]);
+  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(newAlloc)) {
+    tmemAlloc.getSrcMutable().clear();
+    builder.setInsertionPointAfter(op);
+    Value trueVal = builder.create<arith::ConstantIntOp>(1, 1);
+    newStore = builder.create<ttng::TMEMStoreOp>(tmemAlloc.getResult(),
+                                                 op->getOperand(0), trueVal);
+  } else {
+    auto localAlloc = cast<ttg::LocalAllocOp>(newAlloc);
+    localAlloc.getSrcMutable().clear();
+    builder.setInsertionPointAfter(op);
+    newStore = builder.create<ttg::LocalStoreOp>(op->getOperand(0),
+                                                 localAlloc.getResult());
+  }
+  op->replaceAllUsesWith(newAlloc);
+  op->erase();
+  return newStore;
+}
+
 void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
                              ttng::MMAv5OpInterface mma, int mmaSelfLatency,
                              ttng::TMEMAllocOp alloc, int phaseArgIdx,
@@ -818,13 +856,24 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
 
   ttng::MMAv5PipelineableOperandsHelper mmaPipeHelper(mma, forOp,
                                                       isLoadToBePipelined);
+
+  SmallVector<Operation *> updatedDefs;
+  for (auto def : mmaPipeHelper.unpipelineableOperandDefs) {
+    auto newStore = hoistBufferOutOfLoop(forOp, def, schedule);
+    if (newStore) {
+      updatedDefs.push_back(newStore);
+    } else {
+      updatedDefs.push_back(def);
+    }
+  }
+
   if (!mmaPipeHelper.isPipelineable &&
       mmaPipeHelper.isOperandsStateDetermined) {
     // If the operands are not pipelineable, we need to insert a sync point
     // before the earliest operand load
-    for (auto load : mmaPipeHelper.unpipelineableOperandLoads) {
-      if (!latestSyncPoint || schedule.isOpBefore(load, *latestSyncPoint)) {
-        latestSyncPoint = load;
+    for (auto def : updatedDefs) {
+      if (!latestSyncPoint || schedule.isOpBefore(def, *latestSyncPoint)) {
+        latestSyncPoint = def;
       }
     }
   }

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 25 additions & 22 deletions
@@ -14,8 +14,8 @@ namespace ttng = mlir::triton::nvidia_gpu;
 // MMA Pipeline Analysis
 //===----------------------------------------------------------------------===//
 
-bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
-    Value v, Operation *&foundLoad) {
+bool ttng::MMAv5PipelineableOperandsHelper::isOperandPipelineable(
+    Value v, Operation *&foundDef) {
   if (forOp.isDefinedOutsideOfLoop(v)) {
     return true;
   }
@@ -25,14 +25,16 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
   while (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(v.getDefiningOp())) {
     v = v.getDefiningOp()->getOperand(0);
   }
-  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(v.getDefiningOp())) {
-    foundLoad = tmemAlloc;
+  if (isa<ttg::LocalStoreOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp>(
+          v.getDefiningOp())) {
+    foundDef = v.getDefiningOp();
     return false;
   }
   auto localAlloc = dyn_cast<ttg::LocalAllocOp>(v.getDefiningOp());
   if (!localAlloc) {
     return false;
   }
+  foundDef = localAlloc;
   if (!localAlloc.getSrc()) {
     return false;
   }
@@ -44,17 +46,18 @@ bool ttng::MMAv5PipelineableOperandsHelper::comesFromLoadOrOutsideLoop(
           localAllocSrc)) {
     return false;
   }
-  foundLoad = localAllocSrc;
-  if (!isLoadToBePipelined(foundLoad)) {
+  foundDef = localAllocSrc;
+  if (!isLoadToBePipelined(localAllocSrc)) {
     return false;
   }
-  if (canBeAsyncLoad(foundLoad)) {
+  if (canBeAsyncLoad(localAllocSrc)) {
    return true;
  }
  return false;
 }
 
 void ttng::MMAv5PipelineableOperandsHelper::run() {
+  unpipelineableOperandDefs.clear();
   isOperandsStateDetermined = true;
   // Accumulator alloc must be outside the loop.
   auto tmemAlloc = mmaOp.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
@@ -65,17 +68,17 @@ void ttng::MMAv5PipelineableOperandsHelper::run() {
     return;
   }
   if (auto dotOp = dyn_cast<tt::DotOpInterface>(mmaOp.getOperation())) {
-    Operation *foundLoad = nullptr;
-    if (!comesFromLoadOrOutsideLoop(dotOp.getA(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    Operation *foundDef = nullptr;
+    if (!isOperandPipelineable(dotOp.getA(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
-    if (!comesFromLoadOrOutsideLoop(dotOp.getB(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    if (!isOperandPipelineable(dotOp.getB(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
@@ -95,24 +98,24 @@ void ttng::MMAv5PipelineableOperandsHelper::run() {
       isOperandsStateDetermined = false;
       return;
     }
-    Operation *foundLoad = nullptr;
-    if (!comesFromLoadOrOutsideLoop(scaledOp.getAScale(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    Operation *foundDef = nullptr;
+    if (!isOperandPipelineable(scaledOp.getAScale(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
-    if (!comesFromLoadOrOutsideLoop(scaledOp.getBScale(), foundLoad)) {
-      if (foundLoad) {
-        unpipelineableOperandLoads.push_back(foundLoad);
+    if (!isOperandPipelineable(scaledOp.getBScale(), foundDef)) {
+      if (foundDef) {
+        unpipelineableOperandDefs.push_back(foundDef);
       } else {
         isOperandsStateDetermined = false;
       }
     }
   }
   isPipelineable =
-      isOperandsStateDetermined && unpipelineableOperandLoads.empty();
+      isOperandsStateDetermined && unpipelineableOperandDefs.empty();
 }
 
 bool ttng::hasAccReadModifyWrite(ttng::MMAv5OpInterface mma, scf::ForOp forOp) {

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

Lines changed: 42 additions & 23 deletions
@@ -39,9 +39,9 @@ struct FenceInsertionPass
     mod.walk([&](DotOpInterface dotOp) {
       Value a = dotOp.getA();
       Value b = dotOp.getB();
-      bool aDependsOnShared = dependOnCopyRegToShared(a);
-      bool bDependsOnShared = dependOnCopyRegToShared(b);
-      if (!aDependsOnShared && !bDependsOnShared)
+      SmallVector<Operation *> copyRegToSharedOpsA = findCopyRegToSharedOps(a);
+      SmallVector<Operation *> copyRegToSharedOpsB = findCopyRegToSharedOps(b);
+      if (copyRegToSharedOpsA.empty() && copyRegToSharedOpsB.empty())
         return WalkResult::advance();
 
       OpBuilder builder(dotOp);
@@ -50,11 +50,13 @@ struct FenceInsertionPass
       // If there is all the dependencies are outside of the loop try to hoist
       // the fence.
       while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
-        if (aDependsOnShared &&
-            loopOp->isAncestor(a.getParentBlock()->getParentOp()))
+        if (!copyRegToSharedOpsA.empty() &&
+            llvm::any_of(copyRegToSharedOpsA,
+                         [&](Operation *op) { return loopOp->isAncestor(op); }))
           break;
-        if (bDependsOnShared &&
-            loopOp->isAncestor(b.getParentBlock()->getParentOp()))
+        if (!copyRegToSharedOpsB.empty() &&
+            llvm::any_of(copyRegToSharedOpsB,
+                         [&](Operation *op) { return loopOp->isAncestor(op); }))
           break;
         loopOp.moveOutOfLoop(fence);
       }
@@ -72,31 +74,47 @@ struct FenceInsertionPass
 
 private:
   // Return true if the operand depends on a copy from register to shared.
-  bool dependOnCopyRegToShared(Value operand) {
+  SmallVector<Operation *> findCopyRegToSharedOps(Value operand) {
    DenseSet<Value> visited;
-    return dependOnCopyRegToShared(operand, visited);
+    llvm::SetVector<Operation *> result;
+    findCopyRegToSharedOps(operand, visited, result);
+    return result.takeVector();
   }
 
-  bool dependOnCopyRegToShared(Value operand, DenseSet<Value> &visited) {
+  void findCopyRegToSharedOps(Value operand, DenseSet<Value> &visited,
+                              llvm::SetVector<Operation *> &result) {
    // If the value has already been visited we can safely return false as we
    // would early return when true.
    if (visited.count(operand))
-      return false;
+      return;
    visited.insert(operand);
    if (!isa<triton::gpu::MemDescType>(operand.getType()))
-      return false;
+      return;
 
    auto op = operand.getDefiningOp();
    if (op) {
      // reach an alloc copying from register, we need a fence.
-      if (isa<ttg::LocalAllocOp>(op) && cast<ttg::LocalAllocOp>(op).getSrc())
-        return true;
+      if (auto localAlloc = dyn_cast<ttg::LocalAllocOp>(op)) {
+        if (localAlloc.getSrc()) {
+          result.insert(op);
+        }
+        // Check if there are local_store ops that write to that buffer.
+        for (auto user : localAlloc.getResult().getUsers()) {
+          while (user->hasOneUse() &&
+                 user->hasTrait<OpTrait::MemDescViewTrait>()) {
+            user = *user->getUsers().begin();
+          }
+          if (isa<ttg::LocalStoreOp>(user)) {
+            result.insert(user);
+            return;
+          }
+        }
+      }
      // if it is not an alloc, iterate over the operands.
      for (auto v : op->getOperands()) {
-        if (dependOnCopyRegToShared(v))
-          return true;
+        findCopyRegToSharedOps(v, visited, result);
      }
-      return false;
+      return;
    }
 
    // reach BlockArgument
@@ -108,22 +126,23 @@ struct FenceInsertionPass
      assert(argNum != 0 && "induction var cannot be memdesc type");
      --argNum;
      // prologue
-      if (dependOnCopyRegToShared(forOp.getInitArgs()[argNum], visited))
-        return true;
+      findCopyRegToSharedOps(forOp.getInitArgs()[argNum], visited, result);
      // yield
      auto yieldOp = forOp.getBody()->getTerminator();
      Value v = yieldOp->getOperand(argNum);
-      return dependOnCopyRegToShared(v, visited);
+      findCopyRegToSharedOps(v, visited, result);
+      return;
    }
 
    // look through `ttg.warp_specialize`.
    if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(argOwner)) {
-      return dependOnCopyRegToShared(
-          wsOp.getParentOp().getExplicitCaptures()[argNum]);
+      findCopyRegToSharedOps(wsOp.getParentOp().getExplicitCaptures()[argNum],
+                             visited, result);
+      return;
    }
 
    // Conservatively return true for other ops
-    return true;
+    result.insert(argOwner);
  }
 };
 
test/TritonGPU/fence-inserstion.mlir

Lines changed: 52 additions & 0 deletions
@@ -20,6 +20,27 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: matmul_like_fence_local_store
+  tt.func public @matmul_like_fence_local_store(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked2>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>
+    ttg.local_store %arg0, %0 : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
+    // CHECK: ttng.fence_async_shared
+    %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
@@ -74,6 +95,37 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: fence_store_in_loop
+  tt.func public @fence_store_in_loop(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x64xf16, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
+    %c64_i32 = arith.constant 64 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %0 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
+    %1 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared1, #smem>
+    // CHECK-NOT: ttng.fence_async_shared
+    // CHECK: scf.for
+    // CHECK: ttng.fence_async_shared
+    // CHECK: ttng.warp_group_dot
+    scf.for %iv0 = %c0_i32 to %c64_i32 step %c32_i32 : i32 {
+      scf.for %iv1 = %c0_i32 to %c64_i32 step %c32_i32 : i32 {
+        ttg.local_store %arg0, %0 : tensor<128x128xf16, #blocked> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
+        %2 = ttng.warp_group_dot %0, %1, %cst : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x64xf16, #shared1, #smem> -> tensor<128x64xf32, #mma>
+      }
+    }
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 64, 16]}>

0 commit comments

Comments
 (0)