Commit 7a342f2

[Pipeliner] Fix backward scheduling over ttg.local_load (#7194)
A small bug was caused because backward stage propagation after warp specialization assumed that all latency ops have stages assigned. This is the correct thing to assume, but right now the stage can sometimes be dropped for various reasons. Work around the problem (for now) by ignoring those ops. This PR also tightens the `const` API of `CoarseSchedule` to help catch bugs like this.
1 parent b655ab7 commit 7a342f2

File tree

4 files changed: +56 -26 lines
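
The failure mode comes from default insertion: CoarseSchedule::operator[] forwards to DenseMap::operator[], so looking up a latency op whose stage was dropped silently creates a <stage 0, invalid cluster> entry instead of reporting "not scheduled". Below is a minimal stand-in sketch of that footgun and of the guarded lookup used as the workaround (std::map with string keys instead of the real DenseMap<Operation *, std::pair<int, Cluster>>; all names invented):

#include <iostream>
#include <map>
#include <string>

int main() {
  // Stand-in for opToStageAndCluster: op name -> assigned stage.
  std::map<std::string, int> stageOf;
  stageOf["tt.load"] = 3;

  // Unguarded operator[]: default-inserts stage 0 for an op that was never
  // scheduled, silently growing the map and faking an assignment.
  int bogus = stageOf["ttg.local_load"];
  std::cout << "unguarded lookup reports stage " << bogus << ", map now has "
            << stageOf.size() << " entries\n";

  // The workaround in this commit: check membership first and skip ops
  // whose stage was dropped.
  if (stageOf.count("ttg.local_alloc") == 0)
    std::cout << "ttg.local_alloc has no stage; skipping\n";
  return 0;
}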

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 9 additions & 9 deletions
@@ -45,7 +45,7 @@ class CoarseSchedule {
   const_iterator begin() const { return orderClusters.begin(); }
   iterator end() { return orderClusters.end(); }
   const_iterator end() const { return orderClusters.end(); }
-  size_t size() { return orderClusters.size(); }
+  size_t size() const { return orderClusters.size(); }
   iterator newAtBack() {
     orderClusters.push_back(orderClusters.size());
     return std::prev(orderClusters.end());
@@ -88,7 +88,7 @@ class CoarseSchedule {
   DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;

   void setNumStages(int numStages) { this->numStages = numStages; }
-  int getNumStages() { return numStages; }
+  int getNumStages() const { return numStages; }

   void insert(Operation *op, int stage, Cluster cluster) {
     if (stage >= numStages) {
@@ -115,7 +115,7 @@ class CoarseSchedule {

   void erase(Operation *op) { opToStageAndCluster.erase(op); }

-  int count(Operation *op) { return opToStageAndCluster.count(op); }
+  int count(Operation *op) const { return opToStageAndCluster.count(op); }

   std::pair<int, Cluster> operator[](Operation *op) {
     return opToStageAndCluster[op];
@@ -129,25 +129,25 @@ class CoarseSchedule {
   Cluster splitClusterBefore(Operation *op, scf::ForOp forOp);

   // Check if op a will show up before op b in the final unrolled code.
-  bool isOpBefore(Operation *a, Operation *b);
+  bool isOpBefore(Operation *a, Operation *b) const;

   // Check if op a is in earlier cluster than op b.
-  bool isOpInEarlierCluster(Operation *a, Operation *b);
+  bool isOpInEarlierCluster(Operation *a, Operation *b) const;

   // Check if op a is in the same cluster as op b.
-  bool isOpInSameCluster(Operation *a, Operation *b);
+  bool isOpInSameCluster(Operation *a, Operation *b) const;

   SmallVector<std::tuple<Operation *, int, Cluster>>
-  getOpsInOrder(scf::ForOp forOp);
+  getOpsInOrder(scf::ForOp forOp) const;
   std::vector<std::pair<Operation *, unsigned>>
-  createFinalSchedule(scf::ForOp forOp);
+  createFinalSchedule(scf::ForOp forOp) const;

   bool empty() const { return opToStageAndCluster.size() == 0; }
   auto end() const { return opToStageAndCluster.end(); }
   auto begin() const { return opToStageAndCluster.begin(); }

   // Set <stage, cluster> based on CoarseSchedule.
-  void serialize(scf::ForOp &forOp);
+  void serialize(scf::ForOp &forOp) const;
   // Create a CoarseSchedule based on forOp's <stage, cluster>.
   LogicalResult deSerialize(scf::ForOp &forOp);

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 15 additions & 15 deletions
@@ -123,11 +123,11 @@ tt::CoarseSchedule::splitClusterBefore(Operation *op, scf::ForOp forOp) {
 }

 // Check if op a will show up before op b in the final unrolled code.
-bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) {
+bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  auto [aStage, aCluster] = opToStageAndCluster[a];
-  auto [bStage, bCluster] = opToStageAndCluster[b];
+  auto [aStage, aCluster] = opToStageAndCluster.at(a);
+  auto [bStage, bCluster] = opToStageAndCluster.at(b);
   if (aStage != bStage) {
     return aStage < bStage;
   }
@@ -137,21 +137,22 @@ bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) {
   return a->isBeforeInBlock(b);
 }

-bool tt::CoarseSchedule::isOpInEarlierCluster(Operation *a, Operation *b) {
+bool tt::CoarseSchedule::isOpInEarlierCluster(Operation *a,
+                                              Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  return clusters.isBefore(opToStageAndCluster[a].second,
-                           opToStageAndCluster[b].second);
+  return clusters.isBefore(opToStageAndCluster.at(a).second,
+                           opToStageAndCluster.at(b).second);
 }

-bool tt::CoarseSchedule::isOpInSameCluster(Operation *a, Operation *b) {
+bool tt::CoarseSchedule::isOpInSameCluster(Operation *a, Operation *b) const {
   assert(opToStageAndCluster.count(a) && opToStageAndCluster.count(b) &&
          "Operations must be in the schedule");
-  return opToStageAndCluster[a].second == opToStageAndCluster[b].second;
+  return opToStageAndCluster.at(a).second == opToStageAndCluster.at(b).second;
 }

 SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
-tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
+tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) const {
   SmallVector<SmallVector<std::tuple<Operation *, int, Cluster>>, 8>
       orderClusters(clusters.size());
   for (auto &op : forOp.getBody()->without_terminator()) {
@@ -160,12 +161,11 @@ tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
       continue;
     }
     auto [stage, cluster] = it->second;
-    if (cluster == Cluster{}) {
-      continue;
-    }
+    assert(cluster != Cluster{} && "Op with invalid cluster!");
     assert(stage < numStages && "Op with invalid stage!");
     int clusterId = *cluster;
-    assert(clusterId == std::distance(clusters.begin(), cluster) &&
+    assert(clusterId == std::distance(clusters.begin(),
+                                      ClusterList::const_iterator(cluster)) &&
            "Cluster ID mismatch!");
     orderClusters[clusterId].push_back(make_tuple(&op, stage, cluster));
   }
@@ -180,7 +180,7 @@ tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
 }

 std::vector<std::pair<Operation *, unsigned>>
-tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) {
+tt::CoarseSchedule::createFinalSchedule(scf::ForOp forOp) const {
   SmallVector<std::tuple<Operation *, int, tt::CoarseSchedule::Cluster>>
       opsInOrder = getOpsInOrder(forOp);
   std::vector<std::pair<Operation *, unsigned>> schedule;
@@ -248,7 +248,7 @@ static std::optional<int> tryGetMaxStage(scf::ForOp &forOp) {
 }

 // Set <stage, cluster> based on CoarseSchedule.
-void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
+void tt::CoarseSchedule::serialize(scf::ForOp &forOp) const {
   for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
     setStageCluster(op, stage, *cluster);
   }
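
Two details of the .cpp change are worth noting: map reads switch from operator[] to DenseMap::at(), which asserts on a missing key instead of inserting one, and the cluster-ID assert in getOpsInOrder now converts the stored cluster iterator explicitly because, in a const member function, clusters.begin() yields a const_iterator and std::distance requires both arguments to have the same iterator type. A small sketch of the latter, with std::list standing in for the pipeliner's cluster list (names invented):

#include <cassert>
#include <iterator>
#include <list>

using ClusterList = std::list<int>;
using Cluster = ClusterList::iterator; // clusters are held as mutable iterators

int clusterIndex(const ClusterList &clusters, Cluster cluster) {
  // Through a const view of the list, begin() yields a const_iterator, while
  // the stored Cluster is a plain iterator; std::distance needs both
  // arguments to have the same type, hence the explicit conversion.
  return std::distance(clusters.begin(), ClusterList::const_iterator(cluster));
}

int main() {
  ClusterList clusters = {0, 1, 2};
  Cluster c = std::next(clusters.begin(), 2);
  assert(clusterIndex(clusters, c) == 2);
  return 0;
}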

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 5 additions & 2 deletions
@@ -184,8 +184,11 @@ CoarseSchedule getInitialSchedule(scf::ForOp forOp,
   // assigned to the same stage.
   DenseSet<int> latencyStages;
   auto ops = forOp.getBody()->without_terminator();
-  for (Operation &op : llvm::make_filter_range(ops, isLatencyOp))
-    latencyStages.insert(schedule[&op].first);
+  for (Operation &op : llvm::make_filter_range(ops, isLatencyOp)) {
+    // FIXME: This should assert all latency ops have an assigned stage.
+    if (schedule.count(&op))
+      latencyStages.insert(schedule[&op].first);
+  }
   if (latencyStages.size() <= 1) {
     CoarseSchedule normalized(/*numStages=*/1);
     auto cluster = normalized.clusters.newAtFront();
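
This loop is the actual workaround from the commit message: latency ops whose stage was dropped are skipped rather than read through the default-inserting lookup, with a FIXME noting that an assert is the eventual goal. A self-contained approximation using standard containers (std::map/std::set instead of the schedule and DenseSet; the op names are only illustrative, taken from the test below):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Stand-ins: stages assigned to latency ops, one of which was dropped.
  std::map<std::string, int> stageOf = {{"tt.load", 3}, {"ttng.tmem_alloc", 3}};
  std::vector<std::string> latencyOps = {"tt.load", "ttng.tmem_alloc",
                                         "ttg.local_load"};

  std::set<int> latencyStages;
  for (const std::string &op : latencyOps) {
    // FIXME (mirroring the commit): ideally every latency op has a stage;
    // for now, skip the ones whose stage was dropped instead of letting a
    // default-inserted stage 0 leak into the set.
    auto it = stageOf.find(op);
    if (it != stageOf.end())
      latencyStages.insert(it->second);
  }
  // With the guard, only the genuinely assigned stage(s) are counted.
  std::cout << "distinct latency stages: " << latencyStages.size() << "\n";
  return 0;
}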

test/TritonGPU/pipeline-schedule-loop.mlir

Lines changed: 27 additions & 0 deletions
@@ -841,3 +841,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   tt.return %2 : tensor<128x128xf16, #blocked1>
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+#tmem_scales = #ttng.tensor_memory_scales_encoding<>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+
+// CHECK-LABEL: @backwards_prop_existing
+tt.func public @backwards_prop_existing(%arg0: i32, %arg1: tensor<128x4x!tt.ptr<i8>, #blocked>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  scf.for %arg2 = %c0_i32 to %arg0 step %c1_i32 : i32 {
+    %0 = tt.load %arg1 {loop.cluster = 2 : i32, loop.stage = 3 : i32} : tensor<128x4x!tt.ptr<i8>, #blocked>
+    %1 = ttg.local_alloc %0 : (tensor<128x4xi8, #blocked>) -> !ttg.memdesc<128x4xi8, #shared, #smem>
+    // CHECK: ttg.local_load %{{.*}} {loop.cluster = 0 : i32, loop.stage = 0 : i32}
+    %2 = ttg.local_load %1 : !ttg.memdesc<128x4xi8, #shared, #smem> -> tensor<128x4xi8, #linear>
+    %result = ttng.tmem_alloc %2 {loop.cluster = 2 : i32, loop.stage = 3 : i32} : (tensor<128x4xi8, #linear>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
+    "use"(%result) {loop.cluster = 2 : i32, loop.stage = 3 : i32} : (!ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>) -> ()
+  } {tt.scheduled_max_stage = 3 : i32, tt.warp_specialize}
+  tt.return
+}
+
+}
