@@ -226,10 +226,15 @@ initSchedule(int maxDist, int stages[SCHED_SIZE], int numStages,
   return success();
 }
 
-void createAndScheduleAsyncCopy(
-    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
-    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
-    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+struct AsyncCopyChainOps {
+  ttg::AsyncCopyGlobalToLocalOp copyOp;
+  ttg::AsyncCommitGroupOp commitOp;
+  ttg::AsyncWaitOp waitOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+AsyncCopyChainOps createAsyncCopy(tt::LoadOp loadOp, Value alloc,
+                                  Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(loadOp);
   Location loc = loadOp.getLoc();
 
@@ -274,9 +279,15 @@ void createAndScheduleAsyncCopy(
   auto sharedLoad =
       builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad, waitOp);
 
+  return {copyOp, commitOp, waitOp, sharedLoad};
+}
+
+void scheduleAsyncCopy(
+    const AsyncCopyChainOps &asyncOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, commitOp, waitOp, localLoadOp] = asyncOps;
   auto [loadStage, loadCluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  // Schedule new ops
   schedule.insert(copyOp, loadStage, loadCluster);
   // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the
   // later UpdateAsyncWaitCount pass can deduce better waitcnts
@@ -292,25 +303,41 @@ void createAndScheduleAsyncCopy(
                     clusters[SCHED_ASYNC_WAIT]);
 
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{sharedLoad});
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
-      sharedLoad->hasOneUse()) {
+      localLoadOp->hasOneUse()) {
     if (auto cvt =
-            dyn_cast<ttg::ConvertLayoutOp>(*sharedLoad->getUsers().begin()))
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
       schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                       clusters[SCHED_LOCAL_LOAD]);
   }
-
-  loadOp.erase();
 }
 
-void createAndScheduleStreamCopy(
+void createAndScheduleAsyncCopy(
     tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
     tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
     const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+
+  auto asyncOps = createAsyncCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{asyncOps.localLoadOp});
+
+  scheduleAsyncCopy(asyncOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
+  loadOp.erase();
+}
+
+struct StreamCopyChainOps {
+  tt::LoadOp copyOp;
+  ttg::MemDescSubviewOp subviewOp;
+  ttg::LocalStoreOp localStoreOp;
+  ttg::LocalLoadOp localLoadOp;
+};
+
+StreamCopyChainOps createStreamCopy(tt::LoadOp loadOp, Value alloc,
+                                    Value extractIdx, scf::ForOp forOp) {
   OpBuilder builder(forOp);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   // Replace the load with insert/extract slice.
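For orientation: the hunks above split IR construction (createAsyncCopy) from schedule assignment (scheduleAsyncCopy), and the new wrapper enforces an ordering the old monolithic function hid. scheduleAsyncCopy still reads schedule[loadOp] to reuse the load's stage and cluster, so the original load can only be removed afterwards. A minimal sketch of the intended composition, using only names from this diff (the surrounding pass state, i.e. schedule, stages, and clusters, is assumed):

    // Build the async_copy -> commit_group -> async_wait -> local_load chain
    // next to the original load.
    AsyncCopyChainOps asyncOps = createAsyncCopy(loadOp, alloc, extractIdx, forOp);
    // Rewire all consumers to the local_load result while loadOp still exists.
    loadOp->replaceAllUsesWith(ValueRange{asyncOps.localLoadOp});
    // Assign stages/clusters; this looks up schedule[loadOp], so the erase
    // below has to come last.
    scheduleAsyncCopy(asyncOps, loadOp, schedule, stages, clusters);
    schedule.erase(loadOp);
    loadOp.erase();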
@@ -319,11 +346,7 @@ void createAndScheduleStreamCopy(
 
   ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
-  Operation *copy = builder.clone(*loadOp);
-
-  auto [stage, cluster] = schedule[loadOp];
-  schedule.erase(loadOp);
-  schedule.insert(copy, stage, cluster);
+  tt::LoadOp copy = cast<tt::LoadOp>(builder.clone(*loadOp));
 
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
@@ -332,43 +355,66 @@ void createAndScheduleStreamCopy(
   auto subviewTy = ttg::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
       allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
-  auto viewLoad =
+  auto subviewOp =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
   // Clean up old local caches.
   SmallVector<ttg::LocalAllocOp> allocsToErase;
   for (Operation *user : loadOp->getUsers()) {
     if (auto userAlloc = dyn_cast<ttg::LocalAllocOp>(user)) {
-      tt::replaceUsesAndPropagateType(builder, userAlloc, viewLoad.getResult());
+      tt::replaceUsesAndPropagateType(builder, userAlloc,
+                                      subviewOp.getResult());
       allocsToErase.push_back(userAlloc);
     }
   }
   for (auto allocToErase : allocsToErase)
     allocToErase.erase();
 
   // Prefetch load ahead of the dot stage if it is used by the dot.
-  auto storeOp =
-      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
-  schedule.insert(viewLoad, stages[SCHED_LOCAL_STORE],
+  auto storeOp = builder.create<ttg::LocalStoreOp>(loc, copy, subviewOp);
+
+  auto sharedLoad =
+      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), subviewOp);
+
+  return {copy, subviewOp, storeOp, sharedLoad};
+}
+
+void scheduleStreamCopy(
+    const StreamCopyChainOps &streamOps, tt::LoadOp loadOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
+  auto [copyOp, subviewOp, localStoreOp, localLoadOp] = streamOps;
+  auto [stage, cluster] = schedule[loadOp];
+  schedule.insert(copyOp, stage, cluster);
+
+  schedule.insert(subviewOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
-  schedule.insert(storeOp, stages[SCHED_LOCAL_STORE],
+  schedule.insert(localStoreOp, stages[SCHED_LOCAL_STORE],
                   clusters[SCHED_LOCAL_STORE]);
 
-  // Create local load
-  auto sharedLoad =
-      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
-  Value result = sharedLoad.getResult();
   if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE])
-    schedule.insert(sharedLoad, stages[SCHED_LOCAL_LOAD],
+    schedule.insert(localLoadOp, stages[SCHED_LOCAL_LOAD],
                     clusters[SCHED_LOCAL_LOAD]);
 
-  loadOp->replaceAllUsesWith(ValueRange{result});
-
-  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) {
-    if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(*result.getUsers().begin()))
+  if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] &&
+      localLoadOp->hasOneUse()) {
+    if (auto cvt =
+            dyn_cast<ttg::ConvertLayoutOp>(*localLoadOp->getUsers().begin()))
       schedule.insert(cvt, stages[SCHED_LOCAL_LOAD],
                       clusters[SCHED_LOCAL_LOAD]);
   }
+}
+
+void createAndScheduleStreamCopy(
+    tt::LoadOp loadOp, Value alloc, Value extractIdx, scf::ForOp forOp,
+    tt::CoarseSchedule &schedule, const int stages[SCHED_SIZE],
+    const std::array<tt::CoarseSchedule::Cluster, SCHED_SIZE> &clusters) {
 
+  auto streamOps = createStreamCopy(loadOp, alloc, extractIdx, forOp);
+  loadOp->replaceAllUsesWith(ValueRange{streamOps.localLoadOp});
+
+  scheduleStreamCopy(streamOps, loadOp, schedule, stages, clusters);
+
+  schedule.erase(loadOp);
   loadOp.erase();
 }
 
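The stream-copy path follows the same create/rewire/schedule/erase sequence; the difference is that its chain starts from a cloned tt.load rather than an async copy chain. As a summary of where scheduleStreamCopy places each op of the chain (read off this diff, not new behavior):

    // copyOp       -> the original load's (stage, cluster), looked up from
    //                 schedule[loadOp] before the wrapper erases the load
    // subviewOp    -> stages[SCHED_LOCAL_STORE], clusters[SCHED_LOCAL_STORE]
    // localStoreOp -> stages[SCHED_LOCAL_STORE], clusters[SCHED_LOCAL_STORE]
    // localLoadOp  -> stages[SCHED_LOCAL_LOAD], clusters[SCHED_LOCAL_LOAD],
    //                 but only when SCHED_LOCAL_LOAD differs from SCHED_COMPUTE;
    //                 a sole ConvertLayoutOp user is pulled into the same slot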