@@ -423,6 +423,98 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
   }
 }
 
+// -----
+
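+// Checks that the ping-pong scheduler still splits the loop into
+// s.setprio/sched.barrier guarded memory and dot clusters when the loop body
+// carries an extra load nested in an scf.if (a prologue load that runs only
+// on the first iteration).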
+// CHECK-LABEL: pingpong_small_prologue_load
+// CHECK: ttg.local_load
+// CHECK: rocdl.s.setprio 1
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier
+// CHECK: ttg.local_load
+// CHECK: rocdl.s.setprio 0
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier
+// CHECK: rocdl.s.setprio 1
+// CHECK: tt.dot
+// CHECK: rocdl.s.setprio 0
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_small_prologue_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<128x64xi32, #blocked1>
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
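+      // Prologue: on the first iteration only, load an extra tile from %arg2
+      // and round-trip it through its own shared-memory buffer.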
+      %26 = arith.cmpi eq, %arg10, %c0_i32 : i32
+      %27 = scf.if %26 -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> {
+        %28 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
+        %29 = tt.broadcast %28 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
+        %30 = tt.load %29 : tensor<128x64x!tt.ptr<f16>, #blocked1>
+        %31 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
+        %32 = ttg.memdesc_subview %31[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
+        ttg.local_store %30, %32 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
+        %33 = ttg.local_load %32 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+        scf.yield %33 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      } else {
+        scf.yield %cst_2 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      }
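+      // Regular pipelined body: global loads for the next iteration, local
+      // loads and the dot for the current one, then stores into the
+      // double-buffered shared-memory tiles.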
+      %34 = tt.addptr %arg12, %cst_1 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
+      %35 = tt.load %34 : tensor<128x64x!tt.ptr<f16>, #blocked1>
+      %36 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %37 = tt.load %36 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %38 = ttg.local_load %arg15 : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      %39 = arith.addf %38, %27 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+      %40 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+      %41 = tt.dot %39, %40, %arg11 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<128x128xf32, #mma>
+      %42 = arith.addi %arg14, %c1_i32 : i32
+      %43 = arith.cmpi slt, %42, %c1_i32 : i32
+      %44 = arith.select %43, %42, %c0_i32 : i32
+      %45 = ttg.memdesc_subview %21[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %35, %45 : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>
+      %46 = ttg.memdesc_subview %22[%44, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %37, %46 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %41, %34, %36, %44, %45, %46 : tensor<128x128xf32, #mma>, tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x128x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+
 // -----
 // CHECK-LABEL: pingpong_medium_dependency
 