Skip to content

Commit 646d063

Browse files
authored
[AMD] Support swizzled shared encodings when lowering AsyncCopy (triton-lang#6369)
Similar to `BufferLoadToLocal` (triton-lang#6329) we can swizzle the global ptrs of `AsyncCopy` between lanes of a warp to get coalesced writes to lds. This PR mostly extracts the common code between `BufferLoadToLocal` and `AsyncCopy` and uses it for both Ops.
1 parent dacd155 commit 646d063

File tree

2 files changed

+257
-158
lines changed

2 files changed

+257
-158
lines changed

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
176176

177177
// -----
178178

179+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
180+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 4, order = [1, 0]}>
181+
#smem = #ttg.shared_memory
182+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
183+
// CHECK-LABEL: async_copy_swizzled_mask_other
184+
tt.func public @async_copy_swizzled_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
185+
%arg1: i32 {tt.divisibility = 16 : i32},
186+
%arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
187+
%arg3: i32 {tt.divisibility = 16 : i32}) {
188+
// We need the splat to allow the AxisAnalysis to work during lowering
189+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
190+
%c0_i32 = arith.constant 0 : i32
191+
%c32_i32 = arith.constant 32 : i32
192+
%c31_i32 = arith.constant 31 : i32
193+
%1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
194+
%29 = arith.addi %arg3, %c31_i32 : i32
195+
%30 = arith.divsi %29, %c32_i32 : i32
196+
%31 = arith.cmpi sgt, %30, %c0_i32 : i32
197+
198+
%51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
199+
%52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
200+
%65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
201+
%66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
202+
%67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
203+
204+
%70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
205+
%71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>
206+
207+
// Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds
208+
// Note that mask/other alignment is 1 so we need 4 conditionals
209+
210+
// CHECK: rocdl.ds_bpermute
211+
// CHECK: rocdl.ballot
212+
// CHECK: llvm.cond_br
213+
// CHECK: rocdl.global.load.lds
214+
// CHECK-NEXT: llvm.br
215+
// CHECK: _predicated_store
216+
217+
// CHECK: rocdl.ds_bpermute
218+
// CHECK: rocdl.ballot
219+
// CHECK: llvm.cond_br
220+
// CHECK: rocdl.global.load.lds
221+
// CHECK-NEXT: llvm.br
222+
// CHECK: _predicated_store
223+
224+
// CHECK: rocdl.ds_bpermute
225+
// CHECK: rocdl.ballot
226+
// CHECK: llvm.cond_br
227+
// CHECK: rocdl.global.load.lds
228+
// CHECK-NEXT: llvm.br
229+
// CHECK: _predicated_store
230+
231+
// CHECK: rocdl.ds_bpermute
232+
// CHECK: rocdl.ballot
233+
// CHECK: llvm.cond_br
234+
// CHECK: rocdl.global.load.lds
235+
// CHECK-NEXT: llvm.br
236+
// CHECK: _predicated_store
237+
238+
%2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
239+
tt.return
240+
}
241+
}
242+
243+
// -----
244+
179245
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
180246
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
181247
#smem = #ttg.shared_memory

0 commit comments

Comments
 (0)