
Commit ac0d4db

[AMD][NFC] Split createAndSchedule* in stream pipeliner (#7514)

Splits `createAndScheduleAsyncCopy` and `createAndScheduleStreamCopy` into separate create* and schedule* helpers so each half can be reused if we want to schedule the ops differently in a future PR.
1 parent b7a0502 commit ac0d4db
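
For context, the refactor separates IR construction from schedule assignment. Below is a minimal sketch of the reuse this enables, built from the helpers introduced in this diff; `scheduleAsyncCopyForPingPong` is a hypothetical alternative policy, not part of this commit:

// Sketch only: pair createAsyncCopy with a different scheduling policy.
// scheduleAsyncCopyForPingPong is hypothetical; everything else mirrors
// the new createAndScheduleAsyncCopy in this diff.
void createAndScheduleAsyncCopyForPingPong(
    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
  // Rewrite the IR; no scheduling decisions happen here.
  AsyncCopyChainOps asyncOps = createAsyncCopy(loadOp, alloc, extractIdx, forOp);
  loadOp->replaceAllUsesWith(ValueRange{asyncOps.localLoadOp});

  // Place the new ops with the alternative policy (hypothetical).
  scheduleAsyncCopyForPingPong(asyncOps, loadOp, schedule, stages, clusters);

  // The original load is fully replaced; drop it from the schedule and the IR.
  schedule.erase(loadOp);
  loadOp.erase();
}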

File tree

1 file changed: +79 -33 lines

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 79 additions & 33 deletions
@@ -226,10 +226,15 @@ initSchedule(int maxDist, int stages[SCHED_SIZE], int numStages,
   return success();
 }
 
-void createAndScheduleAsyncCopy(
-    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
-    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
-    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+struct AsyncCopyChainOps {
+  ttg::AsyncCopyGlobalToLocalOp copyOp;
+  ttg::AsyncCommitGroupOp commitOp;
+  ttg::AsyncWaitOp waitOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc,
+                                  Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(loadOp);
   Location loc = loadOp.getLoc();
 
@@ -274,9 +279,15 @@ void createAndScheduleAsyncCopy(
   auto sharedLoad =
       builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad, waitOp);
 
+  return {copyOp, commitOp, waitOp, sharedLoad};
+}
+
+void scheduleAsyncCopy(
+    const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, commitOp, waitOp, localLoadOp] = asyncOps;
   auto [loadStage, loadCluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  // Schedule new ops
   schedule.insert(copyOp, loadStage, loadCluster);
   // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the
   // later UpdateAsyncWaitCount pass can deduce better waitcnts
@@ -292,25 +303,41 @@ void createAndScheduleAsyncCopy(
                   clusters[SCHED_ASYNC_WAIT]);
 
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{sharedLoad});
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
-      sharedLoad->hasOneUse()) {
+      localLoadOp->hasOneUse()) {
     if (auto cvt =
-            dyn_cast<ttg::ConvertLayoutOp>(*sharedLoad->getUsers().begin()))
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
       schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                       clusters[SCHED_LOCAL_LOAD]);
   }
-
-  loadOp.erase();
 }
 
-void createAndScheduleStreamCopy(
+void createAndScheduleAsyncCopy(
     tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
     tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
     const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+
+  auto asyncOps = createAsyncCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{asyncOps.localLoadOp});
+
+  scheduleAsyncCopy(asyncOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
+  loadOp.erase();
+}
+
+struct StreamCopyChainOps {
+  tt::LoadOp copyOp;
+  ttg::MemDescSubviewOp subviewOp;
+  ttg::LocalStoreOp localStoreOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc,
+                                    Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(forOp);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   // Replace the load with insert/extract slice.
@@ -319,11 +346,7 @@ void createAndScheduleStreamCopy(
 
   ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
-  Operation *copy = builder.clone(*loadOp);
-
-  auto [stage, cluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  schedule.insert(copy, stage, cluster);
+  tt::LoadOp copy = cast<tt::LoadOp>(builder.clone(*loadOp));
 
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
@@ -332,43 +355,66 @@
   auto subviewTy = ttg::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
       allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
-  auto viewLoad =
+  auto subviewOp =
      builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
   // Clean up old local caches.
   SmallVector<ttg::LocalAllocOp> allocsToErase;
   for (Operation *user : loadOp->getUsers()) {
     if (auto userAlloc = dyn_cast<ttg::LocalAllocOp>(user)) {
-      tt::replaceUsesAndPropagateType(builder, userAlloc, viewLoad.getResult());
+      tt::replaceUsesAndPropagateType(builder, userAlloc,
+                                      subviewOp.getResult());
       allocsToErase.push_back(userAlloc);
     }
   }
   for (auto allocToErase : allocsToErase)
     allocToErase.erase();
 
   // Prefetch load ahead of the dot stage if is used by the dot.
-  auto storeOp =
-      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
-  schedule.insert(viewLoad, stages[SCHED_LOCAL_STORE],
+  auto storeOp = builder.create<ttg::LocalStoreOp>(loc, copy, subviewOp);
+
+  auto sharedLoad =
+      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), subviewOp);
+
+  return {copy, subviewOp, storeOp, sharedLoad};
+}
+
+void scheduleStreamCopy(
+    const StreamCopyChainOps &streamOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, subviewOp, localStoreOp, localLoadOp] = streamOps;
+  auto [stage, cluster] = schedule[loadOp];
+  schedule.insert(copyOp, stage, cluster);
+
+  schedule.insert(subviewOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
-  schedule.insert(storeOp, stages[SCHED_LOCAL_STORE],
+  schedule.insert(localStoreOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
 
-  // Create local load
-  auto sharedLoad =
-      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
-  Value result = sharedLoad.getResult();
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{result});
-
-  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) {
-    if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(*result.getUsers().begin()))
+  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
+      localLoadOp->hasOneUse()) {
+    if (auto cvt =
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
      schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                      clusters[SCHED_LOCAL_LOAD]);
   }
+}
+
+void createAndScheduleStreamCopy(
+    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
 
+  auto streamOps = createStreamCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{streamOps.localLoadOp});
+
+  scheduleStreamCopy(streamOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
   loadOp.erase();
 }
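
As a usage note, the stream-copy path gets the same decoupling: a caller can take the ops returned by createStreamCopy and place them itself instead of calling scheduleStreamCopy. A hedged sketch using only calls visible in this diff; the single-stage placement policy below is illustrative, not part of this commit:

// Sketch: reuse createStreamCopy but keep every new op in the original
// load's stage and cluster (illustrative policy).
auto streamOps = createStreamCopy(loadOp, alloc, extractIdx, forOp);
loadOp->replaceAllUsesWith(ValueRange{streamOps.localLoadOp});

auto [stage, cluster] = schedule[loadOp];
schedule.insert(streamOps.copyOp, stage, cluster);
schedule.insert(streamOps.subviewOp, stage, cluster);
schedule.insert(streamOps.localStoreOp, stage, cluster);
schedule.insert(streamOps.localLoadOp, stage, cluster);

schedule.erase(loadOp);
loadOp.erase();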
