@@ -422,3 +422,220 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+// CHECK-LABEL: pingpong_medium_dependency
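+// Medium-tile case: the scheduler is expected to split the dot into two slices,
+// cluster the shared-memory slice loads and global loads up front behind
+// sched.barrier fences, and wrap each dot slice in s.setprio/gpu.barrier so the
+// two warp groups ping-pong between memory and compute phases.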
+
+// CHECK: gpu.barrier
+// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
+// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
+// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
+// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
+// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
+// CHECK: scf.for
+
+// CHECK: %[[SLICEA0:.+]] = ttg.local_load
+// CHECK: %[[SLICEB0:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: %[[SLICEA1:.+]] = ttg.local_load
+// CHECK: %[[SLICEB1:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.s.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.local_store
+// CHECK: ttg.local_store
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: scf.yield
+// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_medium_dependency(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x128xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+      %33 = arith.addf %32, %cst_2 : tensor<256x128xf32, #mma>
+      %34 = arith.addi %arg14, %c1_i32 : i32
+      %35 = arith.cmpi slt, %34, %c1_i32 : i32
+      %36 = arith.select %35, %34, %c0_i32 : i32
+      %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+// -----
+// CHECK-LABEL: pingpong_large_dependency
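+// Large-tile case: the dot is expected to be split into four slices, with the
+// global loads interleaved between the first two slices and the local_stores
+// placed before the final slice; each slice is again fenced by s.setprio and
+// gpu.barrier for the ping-pong schedule.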
+
+// CHECK: gpu.barrier
+// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
+// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
+// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
+// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
+// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
+// CHECK: scf.for
+// CHECK: tt.load
+// CHECK: %[[SLICEA0:.+]] = ttg.local_load
+// CHECK: %[[SLICEB0:.+]] = ttg.local_load
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: %[[SLICEA1:.+]] = ttg.local_load
+// CHECK: %[[SLICEB1:.+]] = ttg.local_load
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: %[[SLICEA2:.+]] = ttg.local_load
+// CHECK: %[[SLICEB2:.+]] = ttg.local_load
+// CHECK: %[[SLICEA3:.+]] = ttg.local_load
+// CHECK: %[[SLICEB3:.+]] = ttg.local_load
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.local_store
+// CHECK: ttg.local_store
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: scf.yield
+// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_large_dependency(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x256xf32, #mma>
+    %c0_i32 = arith.constant 0 : i32
+    %c63_i32 = arith.constant 63 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x256x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x256xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x256x!tt.ptr<f16>, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
+      %33 = arith.addf %32, %cst_2 : tensor<256x256xf32, #mma>
+      %34 = arith.addi %arg14, %c1_i32 : i32
+      %35 = arith.cmpi slt, %34, %c1_i32 : i32
+      %36 = arith.select %35, %34, %c0_i32 : i32
+      %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %29, %38 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}