
Commit 4c298b9

njriasan (Nick Riasanovsky) authored and committed
[AMD] Fix block ping-pong reordering for persistent matmul (triton-lang#5986)
For larger tiles, the ping-pong scheduler reorders memory and compute operations into slices. However, the current implementation makes the incorrect assumption that it is always legal/safe to simply move the second dot product after the current local-load prefetch. For persistent matmul kernels, which may contain an epilogue that uses the dot result multiple times inside the loop, this produces invalid code. The more robust way to handle this situation is to move the prefetch before the epilogue in these kernels, so that the final result of the dot product is always available at the same point in the code.

---------

Co-authored-by: Nick Riasanovsky <[email protected]>
1 parent a8d53ad commit 4c298b9
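
In schedule terms, the reordered loop body now looks roughly like the sketch below. This is a simplified summary assembled from the CHECK patterns in the test added here, not the literal pass output; the accumulator names are illustrative, and the trailing epilogue op is just the one these tests happen to use (an arith.addf on the dot result):

    scf.for ... {
      %a0 = ttg.local_load ...     // slice 0 of the operand tiles
      %b0 = ttg.local_load ...
      rocdl.sched.barrier 0        // fence the scheduler between clusters
      tt.load ...                  // global prefetch for a later iteration
      %a1 = ttg.local_load ...     // slice 1, prefetched before the epilogue
      %b1 = ttg.local_load ...
      rocdl.s.setprio 1            // raise priority around each dot slice
      %d0 = tt.dot %a0, %b0, %acc
      rocdl.s.setprio 0
      gpu.barrier
      ttg.local_store ...          // stage the prefetched tiles in LDS
      rocdl.s.setprio 1
      %d1 = tt.dot %a1, %b1, %d0
      rocdl.s.setprio 0
      %out = arith.addf %d1, ...   // epilogue sees the finished dot result
      scf.yield ...
    }

Because the slice prefetches now land before the epilogue rather than after the second dot, the dot result is defined at a fixed point in the loop even when the epilogue consumes it more than once.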

File tree

2 files changed: +256, -9 lines


test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 217 additions & 0 deletions
@@ -422,3 +422,220 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
    tt.return
  }
}

// -----
// CHECK-LABEL: pingpong_medium_dependency

// CHECK: gpu.barrier
// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
// CHECK: scf.for

// CHECK: %[[SLICEA0:.+]] = ttg.local_load
// CHECK: %[[SLICEB0:.+]] = ttg.local_load
// CHECK: rocdl.sched.barrier 0
// CHECK: tt.load
// CHECK: rocdl.sched.barrier 0
// CHECK: %[[SLICEA1:.+]] = ttg.local_load
// CHECK: %[[SLICEB1:.+]] = ttg.local_load
// CHECK: rocdl.sched.barrier 0
// CHECK: tt.load
// CHECK: rocdl.s.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: ttg.local_store
// CHECK: ttg.local_store
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: scf.yield
// CHECK: amdgpu.cond_barrier %[[WARPLOW]]

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @pingpong_medium_dependency(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %c1_i32 = arith.constant 1 : i32
    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
    %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x128xf32, #mma>
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
    %1 = tt.get_program_id x : i32
    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
      %33 = arith.addf %32, %cst_2 : tensor<256x128xf32, #mma>
      %34 = arith.addi %arg14, %c1_i32 : i32
      %35 = arith.cmpi slt, %34, %c1_i32 : i32
      %36 = arith.select %35, %34, %c0_i32 : i32
      %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
      ttg.local_store %29, %38 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
      scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
    }
    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
    tt.return
  }
}

// -----
// CHECK-LABEL: pingpong_large_dependency

// CHECK: gpu.barrier
// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
// CHECK: scf.for
// CHECK: tt.load
// CHECK: %[[SLICEA0:.+]] = ttg.local_load
// CHECK: %[[SLICEB0:.+]] = ttg.local_load
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: tt.load
// CHECK: %[[SLICEA1:.+]] = ttg.local_load
// CHECK: %[[SLICEB1:.+]] = ttg.local_load
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: %[[SLICEA2:.+]] = ttg.local_load
// CHECK: %[[SLICEB2:.+]] = ttg.local_load
// CHECK: %[[SLICEA3:.+]] = ttg.local_load
// CHECK: %[[SLICEB3:.+]] = ttg.local_load
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: ttg.local_store
// CHECK: ttg.local_store
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: rocdl.s.setprio 1
// CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
// CHECK: rocdl.s.setprio 0
// CHECK: gpu.barrier
// CHECK: rocdl.sched.barrier 0
// CHECK: scf.yield
// CHECK: amdgpu.cond_barrier %[[WARPLOW]]

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @pingpong_large_dependency(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %c1_i32 = arith.constant 1 : i32
    %cst_0 = arith.constant dense<64> : tensor<64x256xi32, #blocked>
    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
    %cst_2 = arith.constant dense<1.000000e+00> : tensor<256x256xf32, #mma>
    %c0_i32 = arith.constant 0 : i32
    %c63_i32 = arith.constant 63 : i32
    %c64_i32 = arith.constant 64 : i32
    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
    %1 = tt.get_program_id x : i32
    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x256x!tt.ptr<f16>, #blocked>
    %19 = tt.splat %arg7 : i32 -> tensor<64x256xi32, #blocked>
    %20 = tt.addptr %18, %19 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable>
    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
      %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
      %29 = tt.load %28 : tensor<64x256x!tt.ptr<f16>, #blocked>
      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
      %33 = arith.addf %32, %cst_2 : tensor<256x256xf32, #mma>
      %34 = arith.addi %arg14, %c1_i32 : i32
      %35 = arith.cmpi slt, %34, %c1_i32 : i32
      %36 = arith.select %35, %34, %c0_i32 : i32
      %37 = ttg.memdesc_subview %21[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      ttg.local_store %27, %37 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      %38 = ttg.memdesc_subview %22[%36, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
      ttg.local_store %29, %38 : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
      scf.yield %33, %26, %28, %36, %37, %38 : tensor<256x256xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable>
    }
    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    ttg.local_dealloc %22 : !ttg.memdesc<1x64x256xf16, #shared1, #ttg.shared_memory, mutable>
    tt.return
  }
}
