
Commit 7c32dad

[AMD] Improve s_xxx instruction placement for fav3 pingpong (#8535)
This PR improves the placement of s_xxx instructions between clusters in FAv3. It has two benefits:

1. s_xxx instructions are not placed in the compute cluster, so they can be co-issued with v_xxx instructions.
2. The idle cycles between clusters are reduced, which helps reduce stall cycles for some v_xxx instructions.

Perf impact on FAv3 with nheads=64 and D=128:

| causal | batch | N_CTX | this pr | main |
| ------ | ----- | ----- | ------- | ---- |
| 0 | 32 | 512 | 655 | 649 |
| 0 | 16 | 1024 | 824 | 811 |
| 0 | 8 | 2048 | 972 | 932 |
| 0 | 4 | 4096 | 1046 | 1002 |
| 0 | 2 | 8192 | 1048 | 1007 |
| 0 | 1 | 16384 | 1064 | 1020 |
| 1 | 32 | 512 | 351 | 350 |
| 1 | 16 | 1024 | 550 | 541 |
| 1 | 8 | 2048 | 728 | 707 |
| 1 | 4 | 4096 | 862 | 827 |
| 1 | 2 | 8192 | 960 | 914 |
| 1 | 1 | 16384 | 895 | 865 |
1 parent 618ec40 commit 7c32dad
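For scale, the rows above work out to improvements of roughly +0.3% to +5% over main, generally larger at bigger N_CTX. A minimal sketch that recomputes those ratios; the rows are copied verbatim from the table, and since the commit message does not state the benchmark's unit, only the relative change is meaningful:

```cpp
#include <cstdio>

int main() {
  // Rows copied verbatim from the table above: {causal, batch, N_CTX, this_pr, main}.
  struct Row { int causal, batch, nctx; double thisPr, mainBr; };
  const Row rows[] = {
      {0, 32, 512, 655, 649},   {0, 16, 1024, 824, 811},
      {0, 8, 2048, 972, 932},   {0, 4, 4096, 1046, 1002},
      {0, 2, 8192, 1048, 1007}, {0, 1, 16384, 1064, 1020},
      {1, 32, 512, 351, 350},   {1, 16, 1024, 550, 541},
      {1, 8, 2048, 728, 707},   {1, 4, 4096, 862, 827},
      {1, 2, 8192, 960, 914},   {1, 1, 16384, 895, 865},
  };
  for (const Row &r : rows)
    std::printf("causal=%d batch=%2d N_CTX=%5d: %+.1f%%\n", r.causal, r.batch,
                r.nctx, 100.0 * (r.thisPr - r.mainBr) / r.mainBr);
  return 0;
}
```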

File tree

3 files changed (+143, -31 lines changed)

test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir

Lines changed: 30 additions & 20 deletions
```diff
@@ -9,30 +9,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK-LABEL: chained_dots_async_loads

 // CHECK: scf.for
-// CHECK: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.barrier
+// CHECK-NEXT: rocdl.sched.barrier 0
 // Compute Cluster1
 // CHECK: tt.dot
-// CHECK: rocdl.s.setprio 1
-// CHECK: ttg.async_wait
 // CHECK: rocdl.sched.barrier 0
-// MemoryCluster2
+// CHECK-NEXT: ttg.async_wait
+// CHECK-NEXT: rocdl.s.setprio 1
+// CHECK-NEXT: rocdl.sched.barrier 0
+// Memory Cluster1
 // CHECK: ttg.local_load
 // CHECK: ttg.async_copy_global_to_local
 // CHECK: ttg.async_commit_group
 // CHECK: rocdl.sched.barrier 0
-// CHECK: rocdl.s.barrier
-// CHECK: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.waitcnt -7937
+// CHECK-NEXT: rocdl.s.barrier
+// CHECK-NEXT: rocdl.sched.barrier 0
 // Compute Cluster2
 // CHECK: tt.dot
-// CHECK: rocdl.s.setprio 1
-// CHECK: ttg.async_wait
 // CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.async_wait
+// CHECK-NEXT: rocdl.s.setprio 1
+// CHECK-NEXT: rocdl.sched.barrier 0
 // Memory Cluster2
 // CHECK: ttg.local_load
 // CHECK: ttg.async_copy_global_to_local
 // CHECK: ttg.async_commit_group
 // CHECK: rocdl.sched.barrier 0
-// CHECK: rocdl.s.barrier
+// CHECK-NEXT: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.waitcnt -7937
 // CHECK-NEXT: scf.yield

 tt.func @chained_dots_async_loads(%arg0: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg1: i32, %arg2: i32, %arg3: !ttg.async.token, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
@@ -76,30 +82,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ

 // CHECK-NOT: rocdl.s
 // CHECK: scf.for
-// CHECK: rocdl.s.setprio 0
+// CHECK: rocdl.s.barrier
+// CHECK-NEXT: rocdl.sched.barrier 0
 // Compute Cluster1
 // CHECK: tt.dot
-// CHECK: rocdl.s.setprio 1
-// CHECK: gpu.barrier
 // CHECK: rocdl.sched.barrier 0
-// MemoryCluster2
+// CHECK-NEXT: gpu.barrier
+// CHECK-NEXT: rocdl.s.setprio 1
+// Memory Cluster1
 // CHECK: ttg.local_store
 // CHECK: ttg.local_load
 // CHECK: tt.load
-// CHECK: rocdl.sched.barrier 0
-// CHECK: rocdl.s.barrier
-// CHECK: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.sched.barrier 0
+// CHECK-NEXT: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.waitcnt -7937
+// CHECK-NEXT: rocdl.s.barrier
+// CHECK-NEXT: rocdl.sched.barrier 0
 // Compute Cluster2
 // CHECK: tt.dot
-// CHECK: rocdl.s.setprio 1
-// CHECK: gpu.barrier
 // CHECK: rocdl.sched.barrier 0
+// CHECK-NEXT: gpu.barrier
+// CHECK-NEXT: rocdl.s.setprio 1
 // Memory Cluster2
 // CHECK: ttg.local_store
 // CHECK: ttg.local_load
 // CHECK: tt.load
-// CHECK: rocdl.sched.barrier 0
-// CHECK: rocdl.s.barrier
+// CHECK-NEXT: rocdl.sched.barrier 0
+// CHECK-NEXT: rocdl.s.setprio 0
+// CHECK-NEXT: rocdl.s.waitcnt -7937
 // CHECK-NEXT: scf.yield

 tt.func @chained_dots_tt_loads(%arg0: tensor<64x16xf16, #blocked>, %arg1: tensor<64x16x!tt.ptr<f16>, #blocked>, %arg2: i32, %arg3: i32, %arg4: tensor<128x16xf32, #mma>, %arg5: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>, %arg6: i32, %arg7: tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, %arg8: tensor<128x16xf32, #mma>, %arg9: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg10: tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>, %arg11: i32, %arg12: i32, %arg13: tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) -> tensor<128x16xf32, #mma> {
```
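A note on the magic number: the `-7937` immediate in the `rocdl.s.waitcnt` checks above is the 32-bit value of `~(0x1f << 8)`, the `ldsOnlyBits` mask defined in `BlockPingpong.cpp` below. It clears the bits the pass treats as the lgkmcnt field (bits 8..12) while leaving every other counter field saturated, meaning "wait until outstanding LDS/lgkm operations reach zero, don't wait on anything else". A standalone sketch verifying the arithmetic (my own check, assuming two's-complement int32):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Same expression as in BlockPingpong.cpp: clear bits 8..12 of the
  // s_waitcnt immediate (the lgkmcnt field as the pass models it) and keep
  // all other counter fields at their all-ones "don't wait" values.
  constexpr int32_t ldsOnlyBits = ~(0x1f << 8);
  static_assert(ldsOnlyBits == -7937, "matches the test's CHECK value");
  static_assert((ldsOnlyBits & (0x1f << 8)) == 0, "lgkmcnt field is zero");
  std::printf("ldsOnlyBits = %d = 0x%08x\n", ldsOnlyBits,
              static_cast<unsigned>(ldsOnlyBits));
  return 0;
}
```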

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 112 additions & 10 deletions
```diff
@@ -670,8 +670,77 @@ LogicalResult Pingponger::transformTwoClusterWithAsyncAndAll(OpBuilder &builder,
 // For ChainedDots with num_stage==4 the pipeliner already places ops in the
 // correct order to allow for efficient pingpong. The loop contains 2 pairs of
 // compute and memory clusters so we only have to place barriers/sched.barriers
-// at the bounaries and give higher priority to memory clusters
-// See ScheduleLoops.cpp:ChainedDotSchedule for details about the schedule
+// at the bounaries and give higher priority to memory clusters.
+// See ScheduleLoops.cpp:ChainedDotSchedule for details about the schedule.
+//
+// Notes
+//
+// 1. Memory Cluster Priority
+// --------------------------
+// We assign higher priority to the memory cluster than the compute cluster.
+//
+// Priority determines which warp issues its next instruction when two warps on
+// the same execution unit both have ready instructions of the same type. In
+// FAv3, we expect two warps to co-execute — one running the compute cluster,
+// and the other running the memory cluster. Both clusters contain `v_xxx`
+// (VALU) instructions.
+//
+// If the compute cluster has higher priority, then its warp will monopolize the
+// issue slots for all `v_xxx` instructions, forcing the memory-cluster warp to
+// wait. This eliminates the overlap between compute and memory phases — exactly
+// what ping-pong scheduling is meant to achieve.
+//
+// By assigning *higher priority* to the memory cluster, we ensure that the warp
+// executing memory instructions can always issue its `v_xxx` operations (for
+// address updates) even when another warp is busy in the compute cluster. This
+// allows true overlap of memory and compute activity.
+//
+// This choice does not significantly stall the compute-cluster warp, since the
+// memory cluster only contains a few `v_xxx` instructions and its memory ops
+// can still co-issue with VALU instructions in the compute cluster.
+//
+// Note: We currently need this priority scheme because the memory cluster
+// contains `v_xxx` instructions for address updates. Ongoing optimizations aim
+// to either remove these instructions or move them into the compute cluster,
+// which would make this priority adjustment unnecessary.
+//
+//
+// 2. Placement of `s_xxx` Instructions in the Memory Cluster
+// ----------------------------------------------------------
+// We place scalar (`s_xxx`) instructions in the memory cluster rather than the
+// compute cluster.
+//
+// The reason is that `s_xxx` and `v_xxx` instructions can only co-issue when
+// they come from *different warps*. Since compute clusters are dominated by
+// VALU instructions, placing `s_xxx` in the memory cluster maximizes co-issue
+// opportunities — the scalar instructions from one warp can execute
+// concurrently with the VALU instructions from another warp.
+//
+// Typical `s_xxx` instructions include:
+// - Control flow: `s_cbranch`
+// - Priority control: `s_setprio`
+// - Synchronization and dependency: `s_waitcnt`
+//
+// These are usually inserted near `s_barrier` boundaries, and the current
+// implementation carefully places them to ensure they belong to the memory
+// cluster, improving overall overlap and utilization.
+//
+//
+// 3. Placement of `s_waitcnt lgkmcnt(0)`
+// --------------------------------------
+// We place `s_waitcnt lgkmcnt(0)` at the *end* of the memory cluster to ensure
+// that all shared-memory load (`ds_read`) instructions have completed before
+// entering the compute cluster.
+//
+// This placement prevents the LLVM backend from inserting additional
+// `s_waitcnt lgkmcnt()` instructions inside the compute cluster based on
+// inferred dependencies between `mfma` and `ds_read` operations.
+//
+// This approach is consistent with the previous design goal: to eliminate all
+// `s_xxx` instructions from the compute cluster so it can run uninterrupted
+// MFMA and VALU operations. Keeping `s_waitcnt lgkmcnt(0)` at the cluster
+// boundary enforces data dependency correctness while preserving the clean
+// separation between memory and compute phases.
 LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,
                                                       Location loc) {
   assert(dotOps.size() == 2);
```
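To make note 1 concrete, below is a deliberately simplified toy model of the issue rule it describes. This is not a simulation of real GFX9 arbitration; the instruction strings and the single shared VALU issue slot are illustrative assumptions. Two warps each have ready v_xxx work, and on a tie the warp with the higher s_setprio value issues first:

```cpp
#include <cstdio>
#include <deque>
#include <string>

// One warp with an s_setprio level and a queue of ready VALU instructions.
struct Warp {
  std::string name;
  int priority; // value set by s_setprio
  std::deque<std::string> readyValu;
};

int main() {
  Warp compute{"compute warp", 0,
               {"v_fma #1", "v_fma #2", "v_fma #3", "v_fma #4"}};
  Warp memory{"memory warp", 1, {"v_add (address update)"}};

  // One VALU issue slot per cycle; ties go to the higher-priority warp.
  int cycle = 0;
  while (!compute.readyValu.empty() || !memory.readyValu.empty()) {
    Warp *winner;
    if (compute.readyValu.empty())
      winner = &memory;
    else if (memory.readyValu.empty())
      winner = &compute;
    else
      winner = (memory.priority > compute.priority) ? &memory : &compute;
    std::printf("cycle %d: %s issues %s\n", cycle++, winner->name.c_str(),
                winner->readyValu.front().c_str());
    winner->readyValu.pop_front();
  }
  return 0;
}
```

With the priorities flipped, the single address-update op would issue only after all four compute ops, which is the serialization of the memory and compute phases that note 1 warns about.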
```diff
@@ -697,39 +766,72 @@ LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,
   builder.setInsertionPointToStart(forOp.getBody());
   // ComputeCluster 1
   updateOpInsertion(dotOps[0]);
-  prependOp(ROCDL::SetPrioOp::create(builder, loc, lowPriority), false);
+  prependOp(ROCDL::SBarrierOp::create(builder, loc), false);
+  prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);

   // MemoryCluster 1
   updateOpInsertion(memoryClusterStartOps[0]);
-  prependOp(ROCDL::SetPrioOp::create(builder, loc, highPriority), false);
+  prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
   if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[0])) {
     // Only append a sched barrier because membar adds a barrier after asyncwait
     appendOp(ROCDL::SchedBarrier::create(builder, loc, 0));
   } else {
     prependOp(gpu::BarrierOp::create(builder, loc), false);
-    prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
   }
+  // Ideally we want the memory cluster to start with
+  //
+  // s_barrier
+  // s_waitcnt vmcnt(x) lgkmcnt(0)
+  // s_setprio 1
+  //
+  // However, the membar pass will put s_waitcnt before s_barrier.
+  // But we can at least put s_setprio in the memory cluster.
+  prependOp(ROCDL::SetPrioOp::create(builder, loc, highPriority), false);

-  // ComputeCluster2
+  // ComputeCluster 2
+  // We want the 2nd compute cluster to start with
+  //
+  // s_setprio 0
+  // s_waitcnt lgkmcnt(0)
+  // s_barrier
+  //
+  // Check note 2 and 3 for details.
+  constexpr int32_t ldsOnlyBits = ~(0x1f << 8);
   updateOpInsertion(dotOps[1]);
   prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
-  prependOp(ROCDL::SBarrierOp::create(builder, loc), false);
   prependOp(ROCDL::SetPrioOp::create(builder, loc, lowPriority), false);
+  prependOp(ROCDL::SWaitcntOp::create(builder, loc, ldsOnlyBits), false);
+  prependOp(ROCDL::SBarrierOp::create(builder, loc), false);
+  prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);

   // MemoryCluster2
   updateOpInsertion(memoryClusterStartOps[1]);
-  prependOp(ROCDL::SetPrioOp::create(builder, loc, highPriority), false);
+  prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
   if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[1])) {
     // Only append a sched barrier because membar adds a barrier after asyncwait
     appendOp(ROCDL::SchedBarrier::create(builder, loc, 0));
   } else {
     prependOp(gpu::BarrierOp::create(builder, loc), false);
-    prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
   }
+  prependOp(ROCDL::SetPrioOp::create(builder, loc, highPriority), false);

+  // We want the loop to end with the following s.t. s_xxx instructions
+  // stays in the memory cluster.
+  //
+  // s_setprio 0
+  // s_waitcnt lgkmcnt(0)
+  // s_cbranch
+  // s_barrier
+  //
+  // Note that we don't insert s_barrier at the end of the loop, since
+  // the llvm backend may schedule the s_xxx instructions used for
+  // loop induction variables after the s_barrier and effectively put
+  // them into the compute cluster. Instead, we insert s_barrier
+  // at the beginning of the loop.
   updateOpInsertion(lastInsertedOp->getBlock()->getTerminator());
   prependOp(ROCDL::SchedBarrier::create(builder, loc, 0), false);
-  prependOp(ROCDL::SBarrierOp::create(builder, loc), false);
+  prependOp(ROCDL::SetPrioOp::create(builder, loc, lowPriority), false);
+  prependOp(ROCDL::SWaitcntOp::create(builder, loc, ldsOnlyBits), false);

   return success();
 }
```

third_party/amd/lib/TritonAMDGPUTransforms/ScheduleLoops.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -390,7 +390,7 @@ buildSchedule(scf::ForOp &forOp, int numStages, const LoadToInfoMap &loadToInfo,
 } // namespace SingleDotSchedule

 // Builds a schedule for loops containing chained dots. This schedule aims to
-// better interleave mfams with alu ops which can be co-executed on GFX9. It
+// better interleave mma with alu ops which can be co-executed on GFX9. It
 // works for loops which have 2 dots where the result of the first is
 // transformed and used by the second dot. The dot ops will be scheduled with a
 // distance of one and the ops in between will be spit into 2 parts. The first
```
