@@ -162,3 +162,123 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_swizzled_simple
+  tt.func public @buffer_load_swizzled_simple(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                              %arg1: !tt.ptr<f16>,
+                                              %arg2: tensor<16x64xi32, #blocked>,
+                                              %arg3: !ttg.memdesc<16x64xf16, #shared, #smem, mutable>) {
+    // Each thread needs to load 2 elements and we load 1 element (sizePerThread) per buffer load instruction
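+    // (16x64 elements over a [1, 1] x [1, 64] x [8, 1] = 8x64 lane tile -> 2 elements per thread, hence the two loads checked below)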
+    // COMMON: rocdl.make.buffer.rsrc
+    // COMMON-NOT: rocdl.make.buffer.rsrc
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    %65 = amdgpu.buffer_load_to_local %arg1[%arg2] into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<16x64xi32, #blocked>] -> <16x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_mask_other
+  tt.func public @buffer_load_to_local_swizzled_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                                           %arg1: !tt.ptr<f16>,
+                                                           %arg2: tensor<32x32xi32, #blocked>,
+                                                           %arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
+                                                           %arg4: i32) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %29 = arith.addi %arg4, %c31_i32 : i32
+    %30 = arith.divsi %29, %c32_i32 : i32
+    %31 = arith.cmpi sgt, %30, %c0_i32 : i32
+
+    %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %65 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
+    %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
+    %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+
+    %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
+    %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>
+
+    // Each thread needs to load 4 elements and we load 1 element (sizePerThread) per buffer load instruction
+    // Note that the mask/other alignment is 1, so we need 4 conditionals
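+    // (32x32 elements over a [1, 1] x [2, 32] x [4, 1] = 8x32 lane tile -> 4 elements per thread, each with its own ballot and predicated store)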
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON-NOT: rocdl.ds_bpermute
+    // COMMON-NOT: rocdl.ballot
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: _predicated_store
+
+    amdgpu.buffer_load_to_local %arg1[%arg2] mask = %67 other = %cst_0 into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<32x32xi32, #blocked>] tensor<32x32xf16, #blocked> -> <32x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 32], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 16, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_vectorized_8xf16
+  tt.func public @buffer_load_to_local_swizzled_vectorized_8xf16(%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+    %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked>
+
+    // Each thread needs to load 8 elements and we load 8 elements (sizePerThread) per buffer load instruction
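+    // (sizePerThread = [8, 1] with order = [0, 1] gives 8 contiguous f16 = 16 bytes per lane, so a single vectorized load per thread is expected on gfx950)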
+    // GFX950: rocdl.make.buffer.rsrc
+    // GFX950: rocdl.ds_bpermute
+    // GFX950: rocdl.raw.ptr.buffer.load.lds
+    // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
+
+    // GFX942 does not support vectorization > 4 bytes, so we cannot lower it
+    // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
+    // GFX942: amdgpu.buffer_load_to_local
+    %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : <f16>[tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}