Skip to content

Commit d183197

Browse files
authored
[AMD] Refactor StreamPipeliner to use more common functions (#7526)
Further refactoring of Streampipeliner.cpp to use more common pipeliner functionality: `triton::createAllocation`, `triton::createSingleBufferView`, `triton::replaceWithSharedLoad` and a bit of general cleanup. Overall NFC except: - The order of LocalDealloc is reversed now - The memdesc of the subview additionally includes the allocSize Also we had no lit test checking that the LocalLoad consumes the AsyncToken so I adjusted one to include the check.
1 parent b64e85b commit d183197

File tree

5 files changed

+117
-203
lines changed

5 files changed

+117
-203
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,12 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
260260

261261
/// Replace all uses of `old` with a local load from `alloc` unless the use is a
262262
/// `ttg.local_alloc` with a matching shared encoding, in which case the shared
263-
/// memory is forwarded directly into the use.
264-
void replaceUsesWithLocalLoad(
265-
OpBuilder &builder, OpResult old,
266-
TypedValue<triton::gpu::MemDescType> alloc,
267-
TypedValue<triton::gpu::AsyncTokenType> token = {});
263+
/// memory is forwarded directly into the use. Returns the `ttg.local_load` if
264+
/// it created one.
265+
triton::gpu::LocalLoadOp
266+
replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
267+
TypedValue<triton::gpu::MemDescType> alloc,
268+
TypedValue<triton::gpu::AsyncTokenType> token = {});
268269

269270
// Return true if the value comes from a load or a block argument.
270271
// This will skip convert layouts and memdesc views.

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,9 +1532,10 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
15321532
op->erase();
15331533
}
15341534

1535-
void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
1536-
TypedValue<ttg::MemDescType> alloc,
1537-
TypedValue<ttg::AsyncTokenType> token) {
1535+
ttg::LocalLoadOp
1536+
replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
1537+
TypedValue<ttg::MemDescType> alloc,
1538+
TypedValue<ttg::AsyncTokenType> token) {
15381539
// Remove redundant local_load -> local_alloc
15391540
auto allocTy = alloc.getType();
15401541
SmallVector<ttg::LocalAllocOp> allocsToErase;
@@ -1549,16 +1550,18 @@ void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
15491550

15501551
// If there are some uses that were not local_allocs, we need to create a
15511552
// local_load for them.
1553+
ttg::LocalLoadOp maybeLocalLoad;
15521554
if (std::distance(old.getUsers().begin(), old.getUsers().end()) >
15531555
allocsToErase.size()) {
15541556
auto loc = old.getOwner()->getLoc();
1555-
auto sharedLoad = builder.template create<ttg::LocalLoadOp>(
1557+
maybeLocalLoad = builder.template create<ttg::LocalLoadOp>(
15561558
loc, old.getType(), alloc, token);
1557-
old.replaceAllUsesWith(sharedLoad.getResult());
1559+
old.replaceAllUsesWith(maybeLocalLoad);
15581560
}
15591561
for (auto alloc : allocsToErase) {
15601562
alloc.erase();
15611563
}
1564+
return maybeLocalLoad;
15621565
}
15631566

15641567
bool comesFromLoadOrBlockArg(Value v) {

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
250250
// Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced
251251
// COMMON-LABEL: loop_with_dot_and_transpose
252252
// COMMON: ttg.local_alloc {{.*}}, mutable>
253-
// COMMON: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable>
253+
// COMMON: ttg.memdesc_trans {{.*}}, mutable, {{.*}} -> {{.*}}, mutable
254254

255255
#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
256256
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -501,9 +501,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
501501
//
502502
// ASYNC: ttg.async_wait
503503
// ASYNC: ttg.async_copy_global_to_local
504-
// ASYNC: ttg.local_load
504+
// ASYNC: ttg.local_load {{.*}} token
505505
// ASYNC: ttg.async_copy_global_to_local
506-
// ASYNC: ttg.local_load
506+
// ASYNC: ttg.local_load {{.*}} token
507507
// ASYNC: ttg.dot
508508

509509
// Epilogue

test/TritonGPU/loop-pipeline.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
462462
// AMD: scf.yield %[[SELECT_33]]
463463
// AMD: }
464464
// AMD: %[[SELECT_37:.*]] = arith.select %[[CMPI_29]], %[[IF_36]], %[[SELECT_33]]
465-
// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_0]]
466-
// AMD: ttg.local_dealloc %[[LOCAL_ALLOC_1]]
465+
// AMD-DAG: ttg.local_dealloc %[[LOCAL_ALLOC_0]]
466+
// AMD-DAG: ttg.local_dealloc %[[LOCAL_ALLOC_1]]
467467
tt.func @indirect_bmm_scalar(%77: i64 {tt.divisibility=16: i32},
468468
%76: index,
469469
%49: tensor<16x16x!tt.ptr<f16>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},

0 commit comments

Comments
 (0)