Skip to content

Commit 9d29969

Browse files
authored
[AMD][NFC] Consolidate initialization in initSchedule for pipeliner (#7556)
Moves all initialization of the `stages` entries into `initSchedule`; this one was missed in the previous PRs.
1 parent 2edb2e7 commit 9d29969

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,17 @@ using LoadToStreamOpMap = llvm::MapVector<Operation *, StreamOpVariant>;
154154
// WARNING: Changing the order of schedule.clusters.newAtBack() calls
155155
// can cause invalid schedules to be produced.
156156
LogicalResult initSchedule(int maxDist, StreamStages &stages, int numStages,
157-
int &numBuffers, bool useAsyncCopy,
157+
int &numBuffers, int globalPrefetch,
158+
int localPrefetch, bool useAsyncCopy,
158159
StreamClusters &clusters,
159160
tt::CoarseSchedule &schedule) {
161+
int lastStage = numStages - 1;
162+
stages[SCHED_GLOBAL_LOAD] = 0;
163+
stages[SCHED_LOCAL_STORE] = globalPrefetch;
164+
stages[SCHED_LOCAL_LOAD] = lastStage - localPrefetch;
165+
stages[SCHED_COMPUTE] = lastStage;
166+
stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
167+
160168
bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
161169
stages[SCHED_LOCAL_STORE] += maxDist;
162170

@@ -607,6 +615,7 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
607615
int globalPrefetch, int localPrefetch, bool useAsyncCopy,
608616
triton::AMD::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
609617
tt::CoarseSchedule schedule(numStages);
618+
StreamStages stages;
610619
StreamClusters clusters;
611620

612621
auto dumpSchedule = [&](llvm::StringRef msg) {
@@ -617,22 +626,15 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
617626
});
618627
};
619628

620-
int numBuffers = 1;
621629
int maxDist = 0;
622630
for (auto &[l, info] : loadToInfo) {
623631
maxDist = std::max(maxDist, info.distToUse);
624632
}
625633

626-
int lastStage = numStages - 1;
627-
StreamStages stages;
628-
stages[SCHED_GLOBAL_LOAD] = 0;
629-
stages[SCHED_LOCAL_STORE] = globalPrefetch;
630-
stages[SCHED_LOCAL_LOAD] = lastStage - localPrefetch;
631-
stages[SCHED_COMPUTE] = lastStage;
632-
stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD];
633-
634-
if (failed(initSchedule(maxDist, stages, numStages, numBuffers, useAsyncCopy,
635-
clusters, schedule)))
634+
int numBuffers = 1;
635+
if (failed(initSchedule(maxDist, stages, numStages, numBuffers,
636+
globalPrefetch, localPrefetch, useAsyncCopy, clusters,
637+
schedule)))
636638
return {};
637639

638640
if (failed(scheduleLoads(loadToInfo, maxDist, numStages, stages, clusters,

0 commit comments

Comments
 (0)