@@ -128,7 +128,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
@@ -227,7 +226,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
@@ -288,6 +286,77 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+// CHECK-LABEL: pingpong_medium_cast
+// CHECK-COUNT-2: local_load
+// CHECK-NOT: setprio
+// CHECK-NOT: barrier
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_medium_cast(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %cast2 = tt.bitcast %29 : tensor<64x128xf16, #blocked> -> tensor<64x128xi16, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %cast = tt.bitcast %31 : tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %cast, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+      %33 = arith.addi %arg14, %c1_i32 : i32
+      %34 = arith.cmpi slt, %33, %c1_i32 : i32
+      %35 = arith.select %34, %33, %c0_i32 : i32
+      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %cast2, %37 : tensor<64x128xi16, #blocked> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+
+// -----
+
+
 // CHECK-LABEL: pingpong_reject
 // CHECK-COUNT-2: local_load
 // CHECK-NOT: local_load
@@ -296,7 +365,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
-#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
 #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {