@@ -176,6 +176,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar

// -----

+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: async_copy_swizzled_mask_other
+  tt.func public @async_copy_swizzled_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                                 %arg1: i32 {tt.divisibility = 16 : i32},
+                                                 %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
+                                                 %arg3: i32 {tt.divisibility = 16 : i32}) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %29 = arith.addi %arg3, %c31_i32 : i32
+    %30 = arith.divsi %29, %c32_i32 : i32
+    %31 = arith.cmpi sgt, %30, %c0_i32 : i32
+
+    %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %65 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
+    %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
+    %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+
+    %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
+    %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>
+
+    // Each thread needs to load 4 elements, and we load 1 (sizePerThread) per global.load.lds.
+    // Note that the mask/other alignment is 1, so we need 4 conditionals.
+
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: rocdl.ballot
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.global.load.lds
+    // CHECK-NEXT: llvm.br
+    // CHECK: _predicated_store
+
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: rocdl.ballot
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.global.load.lds
+    // CHECK-NEXT: llvm.br
+    // CHECK: _predicated_store
+
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: rocdl.ballot
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.global.load.lds
+    // CHECK-NEXT: llvm.br
+    // CHECK: _predicated_store
+
+    // CHECK: rocdl.ds_bpermute
+    // CHECK: rocdl.ballot
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.global.load.lds
+    // CHECK-NEXT: llvm.br
+    // CHECK: _predicated_store
+
+    %2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory