@@ -104,9 +104,7 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
104104 }
105105
106106 // Assign stage to each op reachable from a latency op
107- for (auto &kv : distance) {
108- Operation *op = kv.first ;
109- int dist = kv.second ;
107+ for (auto [op, dist] : distance) {
110108 // We only schedule ops that are downstream of a latency op
111109 // (had a non-negative distance due to a latency op).
112110 if (dist >= 0 )
@@ -120,16 +118,31 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
120118 for (int i = 0 ; i <= maxStage; i++) {
121119 clusters[i] = schedule.clusters .newAtBack ();
122120 }
123- CoarseSchedule::Cluster epilogue = schedule.clusters .newAtBack ();
124121 // Assign ops to the clusters in reverse-stage order;
125122 // ops with higher stage numbers are assigned first. This way we will
126123 // end up with roughly reverse program order in the clusters.
124+ for (auto [op, stage] : opToStage)
125+ schedule.insert (op, stage, clusters[maxStage - stage]);
126+
127+ // Move `scf.if` ops in the current schedule (forward slice of the latency
128+ // ops) into a new epilogue cluster at the end of the schedule, pushing them
129+ // as close to the end of the loop body as possible.
130+ CoarseSchedule::Cluster epilogue = schedule.clusters .newAtBack ();
127131 for (auto [op, stage] : opToStage) {
128- if (isa <scf::IfOp>(op)) {
129- schedule. insert (op, stage, epilogue);
132+ auto ifOp = dyn_cast <scf::IfOp>(op);
133+ if (!ifOp)
130134 continue ;
131- }
132- schedule.insert (op, stage, clusters[maxStage - stage]);
135+ // If the `scf.if` op itself is a latency op, skip it.
136+ if (opLatency.contains (ifOp))
137+ continue ;
138+ // Avoid scheduling conflicts: only move the `scf.if` when its forward
139+ // slice contains no ops that are already scheduled; otherwise the `scf.if`
140+ // would be scheduled after its dependents.
141+ SetVector<Operation *> slice;
142+ getForwardSlice (ifOp, &slice);
143+ if (llvm::any_of (slice, [&](Operation *op) { return opToStage.count (op); }))
144+ continue ;
145+ schedule.insert (ifOp, stage, epilogue);
133146 }
134147
135148 return schedule;
@@ -140,16 +153,6 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
140153void scheduleDistanceOneDependencies (scf::ForOp forOp,
141154 CoarseSchedule &schedule) {
142155 int numStages = schedule.numStages ;
143- auto getNestedOperands = [](Operation *op) -> SmallVector<Value> {
144- SmallVector<Value> operands;
145- op->walk ([&](Operation *nestedOp) {
146- for (Value operand : nestedOp->getOperands ()) {
147- if (operand.getParentBlock ()->getParentOp ()->isAncestor (nestedOp))
148- operands.push_back (operand);
149- }
150- });
151- return operands;
152- };
153156
154157 // Mapping from the cluster to the cluster before it.
155158 DenseMap<CoarseSchedule::Cluster *, CoarseSchedule::Cluster> dist1Cluster;
@@ -206,6 +209,7 @@ CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp,
206209 SetVector<Operation *> backwardSlice;
207210 BackwardSliceOptions opt;
208211 opt.omitBlockArguments = true ;
212+ opt.omitUsesFromAbove = false ;
209213 getBackwardSlice ((Operation *)op, &backwardSlice, opt);
210214
211215 for (auto op : backwardSlice) {
@@ -218,7 +222,7 @@ CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp,
218222 if (!ifsToStage.empty ()) {
219223 CoarseSchedule::Cluster prologueCluster = schedule.clusters .newAtFront ();
220224 for (auto [ifOp, stage] : ifsToStage) {
221- schedule.insert (ifOp, stage, prologueCluster);
225+ schedule.insertIfAbsent (ifOp, stage, prologueCluster);
222226 }
223227 }
224228
@@ -341,6 +345,11 @@ class TritonGPULoopSchedulingPass
341345 // only for loops missing the latency information.
342346 DenseMap<Operation *, int > opLatency =
343347 assignLatencies (getOperation (), numStages);
348+ LLVM_DEBUG ({
349+ LDBG (" Assigned latencies:\n " );
350+ for (auto [op, latency] : opLatency)
351+ LDBG (" " << latency << " : " << *op);
352+ });
344353 // numStages should not be used below this point. We should know everything
345354 // based on the assigned stages
346355
0 commit comments