
Commit e7072a3

[Pipeliner] Make loop flattening and scheduler more robust (#5902)
Loop flattening: don't forward all outer-loop results through the generated epilogue, to avoid creating dependencies on the prologue. Scheduling: overwrite the stage and cluster of an op if it is requested to be scheduled earlier than previously recorded. This can sometimes happen due to the order in which certain ops have their dependencies scheduled; without the overwrite, the generated schedule is guaranteed to be invalid.
1 parent 28396a7 · commit e7072a3
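As a reading aid for the scheduler half of this change, here is a minimal, self-contained sketch of the "earlier wins" insertion rule. It models the schedule with a plain std::map keyed by op name and an integer cluster index; MiniSchedule and its members are illustrative names, not the Triton API, and the real CoarseSchedule::insertMinimum decides which cluster is earlier by walking its cluster list rather than comparing integers.

// Illustrative sketch only: the "earlier wins" rule, with a std::map and an
// integer cluster index standing in for CoarseSchedule's cluster list.
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct MiniSchedule {
  // op name -> (stage, cluster index); smaller means earlier.
  std::map<std::string, std::pair<int, int>> opToStageAndCluster;

  // Returns true if the op was newly inserted or moved earlier.
  bool insertMinimum(const std::string &op, int stage, int cluster) {
    auto res = opToStageAndCluster.insert({op, {stage, cluster}});
    if (res.second)
      return true; // first request for this op
    auto &[existingStage, existingCluster] = res.first->second;
    if (stage < existingStage) { // an earlier stage always overwrites
      existingStage = stage;
      existingCluster = cluster;
      return true;
    }
    if (stage > existingStage) // a later stage never does
      return false;
    if (cluster < existingCluster) { // same stage: earlier cluster wins
      existingCluster = cluster;
      return true;
    }
    return false;
  }
};

int main() {
  MiniSchedule sched;
  sched.insertMinimum("load", /*stage=*/2, /*cluster=*/1);
  // A later dependency walk asks for an earlier stage; the entry moves.
  bool moved = sched.insertMinimum("load", /*stage=*/0, /*cluster=*/3);
  std::cout << std::boolalpha << moved << '\n'; // prints: true
}

In the diff below, insertDepsOfOp keeps its old insert-if-absent behavior when insertIfEarlier is false; passing true lets a dependency walk pull an already-scheduled op to an earlier stage or cluster, which is what keeps the generated schedule valid.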

File tree

6 files changed: +180, -30 lines changed

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 2 additions & 19 deletions
@@ -106,27 +106,10 @@ class CoarseSchedule {
     return true;
   }
 
-  void insertMinimum(Operation *op, int stage, Cluster cluster) {
-    auto res = opToStageAndCluster.insert({op, {stage, cluster}});
-    if (res.second) {
-      return;
-    }
-    auto &[existingStage, existingCluster] = res.first->second;
-    existingStage = std::min(stage, existingStage);
-
-    // If existingCluster is reachable from cluster,
-    // then cluster is earlier in the list
-    auto it = cluster;
-    for (auto it = cluster; it != clusters.end(); ++it) {
-      if (it == existingCluster) {
-        existingCluster = cluster;
-        return;
-      }
-    }
-  }
+  bool insertMinimum(Operation *op, int stage, Cluster cluster);
 
   bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
-                      bool includeArg);
+                      bool includeArg, bool insertIfEarlier = false);
 
   void erase(Operation *op) { opToStageAndCluster.erase(op); }
 
lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 20 additions & 7 deletions
@@ -800,25 +800,38 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   //   if T == len_j0 + len_j1 + ... + len_jN - N - 1:
   //     epilogue(i)
   Logue &epilogue = logues.back();
+
+  // The only possible use of an epilogue output is the yield.
+  auto outerYield = cast<scf::YieldOp>(outer.getBody()->getTerminator());
+  SmallVector<Value> usedIterArgs;
+  for (Value output : epilogue.getOutputs()) {
+    for (OpOperand &use : output.getUses()) {
+      if (use.getOwner() == outerYield) {
+        usedIterArgs.push_back(fused.getRegionIterArgs().drop_front(
+            outerArgsStartIdx)[use.getOperandNumber()]);
+      }
+    }
+  }
+
   auto epilogueCond =
       b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, T,
                               b.create<arith::SubIOp>(innerLen, intTyCst(1)));
   auto epilogueIf =
-      b.create<scf::IfOp>(outer.getYieldedValues().getTypes(), epilogueCond);
+      b.create<scf::IfOp>(epilogue.getOutputTypes(), epilogueCond);
 
   Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion());
   epilogue.moveBefore(thenBlock, thenBlock->end());
 
   b.setInsertionPointToEnd(thenBlock);
-  b.create<scf::YieldOp>(outer.getYieldedValues());
-
+  b.create<scf::YieldOp>(epilogue.getOutputs());
   b.createBlock(&epilogueIf.getElseRegion());
-  b.create<scf::YieldOp>(fused.getRegionIterArgs().slice(
-      outerArgsStartIdx, outer.getNumRegionIterArgs()));
+  b.create<scf::YieldOp>(usedIterArgs);
+  epilogue.replaceAllUsesWith(epilogueIf.getResults(),
+                              epilogueIf.getThenRegion());
 
   // Finally, create the yield of the fused loop.
   SmallVector<Value> outerOuts{T, i};
-  llvm::append_range(outerOuts, epilogueIf.getResults());
+  llvm::append_range(outerOuts, outerYield.getOperands());
   for (scf::IfOp bodyIf : bodyIfs)
     outerOuts.push_back(/*jk=*/bodyIf.getResult(0));
   for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) {
@@ -831,7 +844,7 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
   }
 
   b.setInsertionPointToEnd(fused.getBody());
-  auto outerYield = b.create<scf::YieldOp>(outerOuts);
+  b.create<scf::YieldOp>(outerOuts);
   outer.replaceAllUsesWith(
       fused.getResults().slice(outerArgsStartIdx, outer.getNumResults()));
 
lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp

Lines changed: 5 additions & 2 deletions
@@ -174,14 +174,17 @@ void scheduleDistanceOneDependencies(scf::ForOp forOp,
             // Exception: Schedule loads with a distance of 1 together
             // with the current op.
             schedule.insertIfAbsent(defOp, stage, cluster);
-            schedule.insertDepsOfOp(defOp, stage, cluster, true);
+            schedule.insertDepsOfOp(defOp, stage, cluster,
+                                    /*includeArg=*/true,
+                                    /*insertIfEarlier=*/true);
           } else {
             if (dist1Cluster.count(&cluster) == 0) {
              dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster);
            }
            schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]);
            schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster],
-                                   true);
+                                   /*includeArg=*/true,
+                                   /*insertIfEarlier=*/true);
          }
        }
      }

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 44 additions & 2 deletions
@@ -14,9 +14,51 @@ using namespace mlir;
 namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
 
+bool tt::CoarseSchedule::insertMinimum(Operation *op, int stage,
+                                       Cluster cluster) {
+  auto res = opToStageAndCluster.insert({op, {stage, cluster}});
+  if (res.second) {
+    return true;
+  }
+
+  auto &[existingStage, existingCluster] = res.first->second;
+
+  // Always insert if the stage is earlier.
+  if (stage < existingStage) {
+    existingStage = stage;
+    existingCluster = cluster;
+    return true;
+  }
+
+  // If the stage is later, no change.
+  if (stage > existingStage) {
+    return false;
+  }
+
+  // If existingCluster is reachable from cluster,
+  // then cluster is earlier in the list
+  auto it = cluster;
+  for (auto it = cluster; it != clusters.end(); ++it) {
+    if (it == existingCluster) {
+      existingCluster = cluster;
+      return true;
+    }
+  }
+
+  // Didn't change the cluster.
+  return false;
+}
+
 bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
                                         tt::CoarseSchedule::Cluster cluster,
-                                        bool includeArg) {
+                                        bool includeArg, bool insertIfEarlier) {
+  auto tryInsert = [&](Operation *op, int stage,
+                       tt::CoarseSchedule::Cluster cluster) {
+    if (!insertIfEarlier)
+      return insertIfAbsent(op, stage, cluster);
+    return insertMinimum(op, stage, cluster);
+  };
+
   bool inserted = false;
   for (Value operand : getNestedOperands(op)) {
     Value v = operand;
@@ -35,7 +77,7 @@ bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage,
     }
     Operation *defOp = v.getDefiningOp();
     if (defOp && defOp->getBlock() == op->getBlock()) {
-      if (insertIfAbsent(defOp, stage, cluster)) {
+      if (tryInsert(defOp, stage, cluster)) {
        inserted = true;
        insertDepsOfOp(defOp, stage, cluster, includeArg);
      }

test/TritonGPU/fuse-nested-loops.mlir

Lines changed: 27 additions & 0 deletions
@@ -540,3 +540,30 @@ tt.func @sink_prologue_to_epilogue(%ub: i32) {
 
   tt.return
 }
+
+// -----
+
+// CHECK-LABEL: @prologue_output
+tt.func @prologue_output(%ub: i32) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  // CHECK: scf.for
+  %0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 {
+    // CHECK: scf.if
+    // CHECK: {increment}
+    %next = arith.addi %k, %c1_i32 {increment} : i32
+    // CHECK: scf.if
+    scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 {
+      // CHECK-NEXT: "body"
+      "body"(%i, %j) : (i32, i32) -> ()
+    }
+    // CHECK: scf.if {{%[0-9]+}} {
+    // CHECK-NEXT: "epilogue"
+    "epilogue"(%i) : (i32) -> ()
+    // CHECK-NEXT: } else {
+    scf.yield %next : i32
+  } {"ttg.always-fuse"}
+
+  tt.return
+}

test/TritonGPU/loop-schedule.mlir

Lines changed: 82 additions & 0 deletions
@@ -62,6 +62,88 @@ tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index,
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK-LABEL: @fused_loop
+tt.func public @fused_loop(%arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) {
+  %c10_i32 = arith.constant 10 : i32
+  %false = arith.constant false
+  %0 = ub.poison : !tt.tensordesc<tensor<64x256xf16>>
+  %cst = arith.constant dense<0> : tensor<128x1xi64, #blocked>
+  %c-1_i32 = arith.constant -1 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c64_i32 = arith.constant 64 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
+
+  %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+  %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+  %3 = arith.extsi %arg7 : i32 to i64
+  %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : <f16>, <tensor<64x256xf16>>
+  %5 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
+  %7 = tt.splat %3 : i64 -> tensor<128x1xi64, #blocked>
+
+  // CHECK: scf.for
+  %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1) : i32 {
+    %9 = arith.addi %arg30, %c1_i32 : i32
+    %10 = arith.cmpi eq, %arg30, %c10_i32 : i32
+    %11 = arith.select %10, %c0_i32, %9 : i32
+    %12 = arith.cmpi eq, %11, %c0_i32 : i32
+
+    // This op is a distance 1 dependency of itself.
+    // CHECK: {_test_marker_0, loop.cluster = 4 : i32, loop.stage = 0 : i32}
+    %13 = arith.select %12, %c0_i32, %arg32 {_test_marker_0} : i32
+
+    %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc<tensor<64x256xf16>>
+    %15 = arith.select %12, %c10_i32, %arg35 : i32
+    %16 = scf.if %12 -> (tensor<128x1xi64, #blocked>) {
+      %32 = arith.muli %cst, %7 : tensor<128x1xi64, #blocked>
+      scf.yield %32 : tensor<128x1xi64, #blocked>
+    } else {
+      scf.yield %arg36 : tensor<128x1xi64, #blocked>
+    }
+    %17 = tt.splat %arg33 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
+    %18 = tt.addptr %17, %16 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
+    %19 = tt.broadcast %18 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
+    %20 = tt.addptr %19, %5 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
+    %21 = tt.addptr %arg33, %c64_i32 : !tt.ptr<f16>, i32
+    %22 = tt.load %20 : tensor<128x64x!tt.ptr<f16>, #blocked>
+    %23 = ttg.local_alloc %22 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %24 = arith.muli %13, %c64_i32 : i32
+    %25 = tt.experimental_descriptor_load %14[%24, %15] : !tt.tensordesc<tensor<64x256xf16>> -> tensor<64x256xf16, #blocked1>
+    %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
+    %27 = ttng.warp_group_dot %23, %26, %arg34, %arg38 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
+    %28 = arith.addi %13, %c1_i32 : i32
+
+    // This op is in the backward slice of `_test_marker_2` and the epilogue.
+    // CHECK: {_test_marker_1, loop.cluster = 3 : i32, loop.stage = 1 : i32}
+    %29 = arith.cmpi eq, %11, %c10_i32 {_test_marker_1} : i32
+
+    // CHECK: {_test_marker_2, loop.cluster = 3 : i32, loop.stage = 1 : i32}
+    %30 = arith.select %29, %arg5, %21 {_test_marker_2} : !tt.ptr<f16>
+
+    %31 = arith.cmpi ne, %11, %c10_i32 : i32
+
+    scf.if %29 {
+      "use"(%27) : (tensor<128x256xf32, #mma>) -> ()
+    // CHECK: {_test_marker_3, loop.cluster = 5 : i32, loop.stage = 2 : i32}
+    } {_test_marker_3}
+    scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1
+  }
+  tt.return
+}
+
+}
+
+// -----
+
 // CHECK-LABEL: @prologue_backward_slice
 tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
   %c0_i32 = arith.constant 0 : i32
