@@ -154,9 +154,17 @@ using LoadToStreamOpMap = llvm::MapVector<Operation *, StreamOpVariant>;
154
154
// WARNING: Changing the order of schedule.clusters.newAtBack() calls
155
155
// can cause invalid schedules to be produced.
156
156
LogicalResult initSchedule (int maxDist, StreamStages &stages, int numStages,
157
- int &numBuffers, bool useAsyncCopy,
157
+ int &numBuffers, int globalPrefetch,
158
+ int localPrefetch, bool useAsyncCopy,
158
159
StreamClusters &clusters,
159
160
tt::CoarseSchedule &schedule) {
161
+ int lastStage = numStages - 1 ;
162
+ stages[SCHED_GLOBAL_LOAD] = 0 ;
163
+ stages[SCHED_LOCAL_STORE] = globalPrefetch;
164
+ stages[SCHED_LOCAL_LOAD] = lastStage - localPrefetch;
165
+ stages[SCHED_COMPUTE] = lastStage;
166
+ stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
167
+
160
168
bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0 ;
161
169
stages[SCHED_LOCAL_STORE] += maxDist;
162
170
@@ -607,6 +615,7 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
607
615
int globalPrefetch, int localPrefetch, bool useAsyncCopy,
608
616
triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
609
617
tt::CoarseSchedule schedule (numStages);
618
+ StreamStages stages;
610
619
StreamClusters clusters;
611
620
612
621
auto dumpSchedule = [&](llvm::StringRef msg) {
@@ -617,22 +626,15 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
617
626
});
618
627
};
619
628
620
- int numBuffers = 1 ;
621
629
int maxDist = 0 ;
622
630
for (auto &[l, info] : loadToInfo) {
623
631
maxDist = std::max (maxDist, info.distToUse );
624
632
}
625
633
626
- int lastStage = numStages - 1 ;
627
- StreamStages stages;
628
- stages[SCHED_GLOBAL_LOAD] = 0 ;
629
- stages[SCHED_LOCAL_STORE] = globalPrefetch;
630
- stages[SCHED_LOCAL_LOAD] = lastStage - localPrefetch;
631
- stages[SCHED_COMPUTE] = lastStage;
632
- stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
633
-
634
- if (failed (initSchedule (maxDist, stages, numStages, numBuffers, useAsyncCopy,
635
- clusters, schedule)))
634
+ int numBuffers = 1 ;
635
+ if (failed (initSchedule (maxDist, stages, numStages, numBuffers,
636
+ globalPrefetch, localPrefetch, useAsyncCopy, clusters,
637
+ schedule)))
636
638
return {};
637
639
638
640
if (failed (scheduleLoads (loadToInfo, maxDist, numStages, stages, clusters,
0 commit comments