Commit c4452b2

[AMD] Initial support for AsyncCopyGlobalToLocal in StreamPipeliner (#6270)
Adds limited support for replacing `tt.load` with `ttg.async_copy_global_to_local` in the stream pipeliner, gated behind `TRITON_HIP_USE_ASYNC_COPY`. Loads that would require an actually swizzled shared encoding are skipped because they are not yet supported in the lowering to `llir`. Swizzling support and performance enablement will follow in separate PRs. A simple matmul lit test is added that explicitly checks for AsyncCopies; most other tests do not exercise AsyncCopy yet because their shared layouts are swizzled.
1 parent 6276a78 commit c4452b2
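
For context, a minimal sketch of how the flag is meant to be exercised from Python, assuming a ROCm build of Triton and PyTorch; the kernel below is illustrative and not part of this commit:

import os

# The AMD backend reads this flag at compile time (see compiler.py below),
# so it must be set before the first kernel is JIT-compiled.
os.environ["TRITON_HIP_USE_ASYNC_COPY"] = "1"

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_k[:, None] * N + offs_n[None, :]
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # The stream pipeliner transforms this k-loop; with the flag set, eligible
    # tt.load ops are rewritten to ttg.async_copy_global_to_local (unswizzled
    # shared layouts only, per this commit).
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        acc += tl.dot(tl.load(a_ptrs), tl.load(b_ptrs))
        a_ptrs += BLOCK_K
        b_ptrs += BLOCK_K * N
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    tl.store(c_ptrs, acc.to(tl.float16))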

File tree

7 files changed: +249 -58 lines


include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_GLOBAL_PREFETCH",
     "TRITON_HIP_LOCAL_PREFETCH",
+    "TRITON_HIP_USE_ASYNC_COPY",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
     "TRITON_LLVM_DEBUG_ONLY",

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 122 additions & 43 deletions
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,SYNC
+// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC
 
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -7,7 +8,7 @@
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
-// CHECK-LABEL: tt.func @load_two_users
+// COMMON-LABEL: tt.func @load_two_users
   tt.func @load_two_users(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) {
     %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked>
     %cst_0 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
@@ -34,13 +35,13 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
     %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
     %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
     %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
-    // CHECK: ttg.local_store
-    // CHECK: scf.for
-    // CHECK: tt.load
-    // CHECK: tt.dot
-    // CHECK: tt.dot
-    // CHECK: ttg.local_store
-    // CHECK: scf.yield
+    // COMMON: ttg.local_store
+    // COMMON: scf.for
+    // COMMON: tt.load
+    // COMMON: tt.dot
+    // COMMON: tt.dot
+    // COMMON: ttg.local_store
+    // COMMON: scf.yield
     %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
       %18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
       %19 = ttg.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
@@ -60,8 +61,8 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
 
 // -----
 
-// CHECK-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de
-// CHECK-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1>
+// COMMON-LABEL: tt.func public @_jagged_hstu_attn_fwd_0d1d2d3d4d5de
+// COMMON-NOT: ttg.convert_layout {{.*}} : tensor<32x64xf32, #shared> -> tensor<32x64xf32, #shared1>
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -166,13 +167,13 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
 
 // Disable pipelining for loops that contain barrier.
 // Barriers are problematic since they are not chained to any other operation.
-// CHECK-LABEL: tt.func public @add_barrier_kernel
-// CHECK: scf.for
-// CHECK: tt.load
-// CHECK: gpu.barrier
-// CHECK: tt.store
-// CHECK-NOT: gpu.barrier
-// CHECK: tt.return
+// COMMON-LABEL: tt.func public @add_barrier_kernel
+// COMMON: scf.for
+// COMMON: tt.load
+// COMMON: gpu.barrier
+// COMMON: tt.store
+// COMMON-NOT: gpu.barrier
+// COMMON: tt.return
 
 #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
@@ -203,11 +204,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
-// CHECK-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
-// CHECK: #ttg.swizzled_shared<{{.*}} order = [2, 1, 0]
-// CHECK-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
+// COMMON-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
+// COMMON: #ttg.swizzled_shared<{{.*}} order = [2, 1, 0]
+// COMMON-NOT: #ttg.swizzled_shared<{{.*}} order = [2, 0, 1]
 
-// CHECK-LABEL: tt.func public @slowest_dim_is_batch
+// COMMON-LABEL: tt.func public @slowest_dim_is_batch
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [4, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [16, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -239,9 +240,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 // Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced
-// CHECK-LABEL: loop_with_dot_and_transpose
-// CHECK: ttg.local_alloc {{.*}}, mutable>
-// CHECK: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable>
+// COMMON-LABEL: loop_with_dot_and_transpose
+// COMMON: ttg.local_alloc {{.*}}, mutable>
+// COMMON: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable>
 
 #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
@@ -270,11 +271,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 // Check that the stream pipeliner updates atomic op in the k-loop correctly
-// CHECK-LABEL: _triton_gemm_kernel_atomic_rmw
-// CHECK: scf.for
-// CHECK: tt.atomic_rmw fadd, acq_rel, gpu
-// CHECK: tt.dot
-// CHECK: scf.yield
+// COMMON-LABEL: _triton_gemm_kernel_atomic_rmw
+// COMMON: scf.for
+// COMMON: tt.atomic_rmw fadd, acq_rel, gpu
+// COMMON: tt.dot
+// COMMON: scf.yield
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
@@ -338,25 +339,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 // Check that we can pipeline scaled dot with linear layout
-// CHECK-LABEL: mxfp8_mxfp4_matmul
+// COMMON-LABEL: mxfp8_mxfp4_matmul
 
 // Prologue
-// CHECK-COUNT-3: ttg.local_alloc
-// CHECK-COUNT-3: tt.load
-// CHECK-COUNT-3: ttg.local_store
+// COMMON-COUNT-3: ttg.local_alloc
+// COMMON-COUNT-3: tt.load
+// COMMON-COUNT-3: ttg.local_store
 
 // Main loop
-// CHECK: scf.for
-// CHECK-COUNT-3: ttg.local_load
-// CHECK: tt.dot_scaled
-// CHECK: scf.yield
+// COMMON: scf.for
+// COMMON-COUNT-3: ttg.local_load
+// COMMON: tt.dot_scaled
+// COMMON: scf.yield
 
 // Epilogue
-// CHECK-COUNT-3: ttg.local_load
-// CHECK: scf.if
-// CHECK: tt.dot_scaled
-// CHECK-COUNT-2: scf.yield
-// CHECK-COUNT-3: ttg.local_dealloc
+// COMMON-COUNT-3: ttg.local_load
+// COMMON: scf.if
+// COMMON: tt.dot_scaled
+// COMMON-COUNT-2: scf.yield
+// COMMON-COUNT-3: ttg.local_dealloc
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
@@ -464,3 +465,81 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+// Check that we can pipeline a simple matmul kernel
+// Note: Currently AsyncCopy is only used for the second operand because we do not support (actual) swizzled shared encodings
+
+// COMMON-LABEL: simple_matmul_kernel
+
+// Prologue
+// COMMON-COUNT-2: ttg.local_alloc
+// SYNC-COUNT-2: tt.load
+// SYNC-COUNT-2: ttg.local_store
+//
+// ASYNC: tt.load
+// ASYNC: ttg.local_store
+// ASYNC: ttg.async_copy_global_to_local
+
+// Main loop
+// COMMON: scf.for
+//
+// SYNC-COUNT-2: ttg.local_load
+// SYNC: tt.dot
+// SYNC: scf.yield
+//
+// ASYNC: ttg.local_load
+// ASYNC: ttg.async_wait
+// ASYNC: ttg.local_load
+// ASYNC: ttg.dot
+// ASYNC: ttg.async_copy_global_to_local
+
+// Epilogue
+// COMMON-COUNT-2: ttg.local_load
+// COMMON: scf.if
+// COMMON: tt.dot
+// COMMON-COUNT-2: scf.yield
+// COMMON-COUNT-2: ttg.local_dealloc
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @simple_matmul_kernel(%test: tensor<1x64xi32, #blocked1>, %arg0: tensor<64x64x!tt.ptr<f16>, #mma>, %arg1: i32, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<32> : tensor<64x32xi32, #blocked>
+    %cst_0 = arith.constant dense<32> : tensor<32x64xi32, #blocked1>
+    %c64_i32 = arith.constant 64 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %1 = arith.muli %arg1, %c64_i32 : i32
+    %2 = tt.splat %1 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %3 = arith.addi %2, %0 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %4 = tt.splat %arg6 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %5 = arith.remsi %3, %4 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %6 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
+    %8 = tt.broadcast %7 : tensor<1x32xi32, #blocked> -> tensor<64x32xi32, #blocked>
+    %9 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<64x32x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
+    %11 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<32x64xi32, #blocked1>
+    %13 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked1>
+    %14 = tt.addptr %13, %12 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
+    %15:3 = scf.for %arg11 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg12 = %cst_1, %arg13 = %10, %arg14 = %14) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>) : i32 {
+      %17 = tt.load %arg13 : tensor<64x32x!tt.ptr<f16>, #blocked>
+      %18 = tt.load %arg14 : tensor<32x64x!tt.ptr<f16>, #blocked1>
+      %19 = ttg.convert_layout %17 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %20 = ttg.convert_layout %18 : tensor<32x64xf16, #blocked1> -> tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %21 = tt.dot %19, %20, %arg12, inputPrecision = tf32 : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
      %22 = tt.addptr %arg13, %cst : tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<64x32xi32, #blocked>
+      %23 = tt.addptr %arg14, %cst_0 : tensor<32x64x!tt.ptr<f16>, #blocked1>, tensor<32x64xi32, #blocked1>
+      scf.yield %21, %22, %23 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr<f16>, #blocked>, tensor<32x64x!tt.ptr<f16>, #blocked1>
+    }
+    %16 = arith.truncf %15#0 : tensor<64x64xf32, #mma> to tensor<64x64xf16, #mma>
+    tt.store %arg0, %16 : tensor<64x64x!tt.ptr<f16>, #mma>
+    tt.return
+  }
+}

third_party/amd/backend/compiler.py

Lines changed: 5 additions & 1 deletion
@@ -239,6 +239,7 @@ def make_ttgir(mod, metadata, options):
 
     global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
     local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
+    use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
 
     # The `local-prefetch` scheduling variant requires turning on buffer ops.
     if options.instruction_sched_variant == "local-prefetch":
@@ -250,7 +251,10 @@ def make_ttgir(mod, metadata, options):
                       "num_stages == 0. Now it will not happen anymore; "
                       "please update to use num_stages == 2 for "
                       "equivalent behavior in the past.")
-    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch)
+    amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch,
+                                           use_async_copy)
+    if use_async_copy:
+        amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
     passes.common.add_canonicalizer(pm)
     if options.instruction_sched_variant.lower() != "none":
        amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
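
As a quick sanity check that the flag took effect, one can inspect the TTGIR attached to the compiled kernel. A minimal sketch, assuming the illustrative `matmul_kernel` from the example above and Triton's usual `asm` dict on the compiled-kernel handle (the exact handle API varies across Triton versions):

import torch

# ROCm builds of PyTorch expose HIP devices under the "cuda" alias.
a = torch.randn(64, 32, device="cuda", dtype=torch.float16)
b = torch.randn(32, 64, device="cuda", dtype=torch.float16)
c = torch.empty(64, 64, device="cuda", dtype=torch.float16)

# In recent Triton versions the launch returns the compiled-kernel handle.
h = matmul_kernel[(1, 1)](a, b, c, 64, 64, 32,
                          BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)

# With TRITON_HIP_USE_ASYNC_COPY=1, eligible loads in the k-loop should now
# show up as async copies in the TTGIR dump.
print("async_copy_global_to_local" in h.asm["ttgir"])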

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@ namespace mlir {
 
 std::unique_ptr<Pass>
 createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
-                                     int localPrefetch = 0);
+                                     int localPrefetch = 0,
+                                     bool useAsyncCopy = false);
 
 std::unique_ptr<Pass>
 createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod
     Option<"localPrefetch", "local_prefetch",
            "int32_t", /*default*/"0",
            "Set local prefetch stage count">,
+    Option<"useAsyncCopy", "use_async_copy",
+           "bool", /*default*/"false",
+           "Use AsyncCopyGlobalToLocal to directly load to shared memory">,
   ];
 }
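The new pass option is also reachable standalone through `triton-opt`, mirroring the RUN lines in the lit test above. A small sketch shelling out from Python, assuming `triton-opt` is on PATH and the command is run from the Triton repo root:

import subprocess

# Mirrors the test's second RUN line: two pipeline stages with async copies on.
subprocess.run([
    "triton-opt", "test/TritonGPU/loop-pipeline-hip.mlir",
    "-split-input-file",
    "-tritonamdgpu-stream-pipeline=num_stages=2 use_async_copy=1",
    "-canonicalize",
], check=True)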