@@ -86,43 +86,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
8686// CHECK: tt.load
8787// CHECK: %[[SLICEA0:.+]] = ttg.local_load
8888// CHECK: %[[SLICEB0:.+]] = ttg.local_load
89- // CHECK: rocdl.sched.barrier 0
9089// CHECK: gpu.barrier
90+ // CHECK: rocdl.sched.barrier 0
9191// CHECK: rocdl.s.setprio 1
9292// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
9393// CHECK: rocdl.s.setprio 0
94- // CHECK: rocdl.sched.barrier 0
9594// CHECK: gpu.barrier
95+ // CHECK: rocdl.sched.barrier 0
9696// CHECK: tt.load
9797// CHECK: %[[SLICEA1:.+]] = ttg.local_load
9898// CHECK: %[[SLICEB1:.+]] = ttg.local_load
99- // CHECK: rocdl.sched.barrier 0
10099// CHECK: gpu.barrier
100+ // CHECK: rocdl.sched.barrier 0
101101// CHECK: rocdl.s.setprio 1
102102// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
103103// CHECK: rocdl.s.setprio 0
104- // CHECK: rocdl.sched.barrier 0
105104// CHECK: gpu.barrier
105+ // CHECK: rocdl.sched.barrier 0
106106// CHECK: %[[SLICEA2:.+]] = ttg.local_load
107107// CHECK: %[[SLICEB2:.+]] = ttg.local_load
108108// CHECK: %[[SLICEA3:.+]] = ttg.local_load
109109// CHECK: %[[SLICEB3:.+]] = ttg.local_load
110- // CHECK: rocdl.sched.barrier 0
111110// CHECK: gpu.barrier
111+ // CHECK: rocdl.sched.barrier 0
112112// CHECK: rocdl.s.setprio 1
113113// CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
114114// CHECK: rocdl.s.setprio 0
115- // CHECK: rocdl.sched.barrier 0
116115// CHECK: gpu.barrier
116+ // CHECK: rocdl.sched.barrier 0
117117// CHECK: ttg.local_store
118118// CHECK: ttg.local_store
119- // CHECK: rocdl.sched.barrier 0
120119// CHECK: gpu.barrier
120+ // CHECK: rocdl.sched.barrier 0
121121// CHECK: rocdl.s.setprio 1
122122// CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
123123// CHECK: rocdl.s.setprio 0
124- // CHECK: rocdl.sched.barrier 0
125124// CHECK: gpu.barrier
125+ // CHECK: rocdl.sched.barrier 0
126126// CHECK: scf.yield
127127// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
128128
@@ -169,9 +169,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
169169 %27 = tt.load %26 : tensor <256 x64 x!tt.ptr <f16 >, #blocked1 >
170170 %28 = tt.addptr %arg13 , %cst_0 : tensor <64 x256 x!tt.ptr <f16 >, #blocked >, tensor <64 x256 xi32 , #blocked >
171171 %29 = tt.load %28 : tensor <64 x256 x!tt.ptr <f16 >, #blocked >
172- %30 = ttg.local_load %arg15 : !ttg.memdesc <256 x64 xf16 , #shared , #ttg.shared_memory , mutable > -> tensor <256 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 8 }>>
173- %31 = ttg.local_load %arg16 : !ttg.memdesc <64 x256 xf16 , #shared1 , #ttg.shared_memory , mutable > -> tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 8 }>>
174- %32 = tt.dot %30 , %31 , %arg11 : tensor <256 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 8 }>> * tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 8 }>> -> tensor <256 x256 xf32 , #mma >
172+ %30 = ttg.local_load %arg15 : !ttg.memdesc <256 x64 xf16 , #shared , #ttg.shared_memory , mutable > -> tensor <256 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>>
173+ %31 = ttg.local_load %arg16 : !ttg.memdesc <64 x256 xf16 , #shared1 , #ttg.shared_memory , mutable > -> tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 4 }>>
174+ %32 = tt.dot %30 , %31 , %arg11 : tensor <256 x64 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>> * tensor <64 x256 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 4 }>> -> tensor <256 x256 xf32 , #mma >
175175 %33 = arith.addi %arg14 , %c1_i32 : i32
176176 %34 = arith.cmpi slt , %33 , %c1_i32 : i32
177177 %35 = arith.select %34 , %33 , %c0_i32 : i32
@@ -189,6 +189,105 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
189189
190190// -----
191191
// Test case: pingpong_medium.
//
// Input: a module for hip:gfx942 with 8 warps whose loop body performs two
// tt.load ops, two ttg.local_load ops, a single 256x64 * 64x128 f16 tt.dot and
// two ttg.local_store ops per iteration.
//
// Expected output, as spelled out by the directives below:
//   * before the loop, a workitem-id based warp predicate is computed and
//     amdgpu.cond_barrier is issued on the "high" predicate;
//   * inside the loop there are two chained tt.dot ops (the second one takes
//     the first one's result as accumulator), each wrapped in
//     rocdl.s.setprio 1 / rocdl.s.setprio 0;
//   * after the loop, amdgpu.cond_barrier is issued on the "low" predicate.

// Anchor the directives to this function so they cannot match output that
// belongs to a neighbouring test case.
// CHECK-LABEL: pingpong_medium
// CHECK: gpu.barrier
// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
// CHECK: scf.for

// CHECK: %[[SLICEA0:.+]] = ttg.local_load
// CHECK: %[[SLICEB0:.+]] = ttg.local_load
// CHECK: rocdl.sched.barrier 0
// CHECK: tt.load
// CHECK: rocdl.sched.barrier 0
// CHECK: %[[SLICEA1:.+]] = ttg.local_load
// CHECK: %[[SLICEB1:.+]] = ttg.local_load
// CHECK: rocdl.sched.barrier 0
// CHECK: tt.load
// CHECK: rocdl.s.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: ttg.local_store
// CHECK: ttg.local_store
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: scf.yield
// CHECK: amdgpu.cond_barrier %[[WARPLOW]]

// Layout / encoding aliases used by the module below (one alias per line).
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @pingpong_medium(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    // Zero-initialised 256x128 accumulator; it seeds the loop-carried %arg11.
    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %c1_i32 = arith.constant 1 : i32
    // Per-iteration pointer increments for the two loaded operands.
    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    // Pointer tensor for the 256x64 operand, built from %arg0.
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
    %1 = tt.get_program_id x : i32
    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
    // Pointer tensor for the 64x128 operand, built from %arg1.
    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
    // One-slot buffers for both operands, plus the initial views passed into the loop.
    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
    // Loop carries: accumulator, both pointer tensors, buffer index, both buffer views.
    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>)  : i32 {
      // Advance the pointers and load the next tiles.
      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
      // Read the current tiles from the buffer views and accumulate the dot product.
      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
      // Next buffer index: (%arg14 + 1) if that is below 1, otherwise 0.
      %33 = arith.addi %arg14, %c1_i32 : i32
      %34 = arith.cmpi slt, %33, %c1_i32 : i32
      %35 = arith.select %34, %33, %c0_i32 : i32
      // Store the freshly loaded tiles into the selected buffer slot.
      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
      ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
    }
    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
    tt.return
  }
}
288+
289+ // -----
290+
192291// CHECK-LABEL: pingpong_reject
193292// CHECK-COUNT-2: local_load
194293// CHECK-NOT: local_load
0 commit comments