Skip to content

Commit 250b6eb

Browse files
[AMD] Enable Pingpong by default on gfx950 arch (#7697)
List of enabling conditions - FP/BF16 GEMM with M,N>64 tilesize when num_stages=3 and num_warps=8 - GEMM using `dot_scaled` with M=N=256 tile size when num_stages=2 and num_warps=8 - FA with num_stages=4 Only with using async_copy.
1 parent cf0db92 commit 250b6eb

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def get_min_dot_size(target: GPUTarget):
1818
return lambda lhs_type, rhs_type: (1, 1, 1)
1919

2020

21-
def is_pingpong_schedule_enabled(arch):
22-
return (arch == "gfx942") if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
21+
def is_pingpong_schedule_enabled(arch, use_async_copy):
22+
return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
23+
) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
2324

2425

2526
def is_in_thread_transpose_enabled(arch):
@@ -218,7 +219,7 @@ def make_ttgir(mod, metadata, options):
218219
global_prefetch = knobs.amd.global_prefetch
219220
local_prefetch = knobs.amd.local_prefetch
220221
use_async_copy = knobs.amd.use_async_copy
221-
use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
222+
use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
222223

223224
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
224225
use_block_pingpong)

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ Pingponger::transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
763763
asyncWaitOp->erase();
764764
}
765765
}
766+
assert(newAsyncWaitOp != nullptr);
766767

767768
moveOpAndPredecessorsUpSameBlock(lLoadOps[0]);
768769
moveOpAndPredecessorsUpSameBlock(lLoadOps[1]);
@@ -917,21 +918,18 @@ void Pingponger::getDotPingponged() {
917918
auto aType = scaledDotOps[0].getA().getType();
918919
auto aShape = aType.getShape();
919920
auto elemWidth = aType.getElementTypeBitWidth();
920-
int64_t tileSize = scaledDotShape[0] * scaledDotShape[1] * aShape[1];
921921

922-
// 256x256x256 (128xi8)
923-
if (tileSize == 8388608 && aShape[0] == 256 && aShape[1] == 128 &&
922+
// MxN = 256x256
923+
if (scaledDotShape[0] == 256 && scaledDotShape[1] == 256 &&
924924
elemWidth == 8) {
925-
kWidth = 16;
926925
if (transformTwoClusterWithAsyncAndAll(builder, scaledDotOps[0]->getLoc())
927926
.failed()) {
928-
LDBG(
929-
"Encountered failure when trying to execute the two-step ping pong "
930-
"cluster transformation");
927+
LDBG("Encountered failure when trying to execute the"
928+
"TwoClusterWithAsyncAndAll transformation");
931929
return;
932930
}
931+
addAsymmetricSyncToLoop(builder, loc);
933932
}
934-
addAsymmetricSyncToLoop(builder, loc);
935933
return;
936934
} else if (scaledDotOps.size() == 1)
937935
return;
@@ -941,7 +939,6 @@ void Pingponger::getDotPingponged() {
941939
// Determine if we have a persistent GEMM. This will decide how we interpret
942940
// any memory operations that we find in conditionals.
943941
auto assumeNotTaken = isPersistentGemm(dotOps.size());
944-
945942
// Compute tile size, kWidth, and mfma type.
946943
auto dotType = dotOps[0].getType();
947944
auto dotShape = dotType.getShape();
@@ -968,11 +965,11 @@ void Pingponger::getDotPingponged() {
968965
LDBG("Currently only support num_warp=8 for async PP");
969966
return;
970967
}
971-
if (numStages > 2 && dotOps.size() == 1 && tileSize == mediumTile &&
972-
aShape[1] == 32 && elemWidth == 16) {
968+
if (numStages > 2 && dotOps.size() == 1 && dotShape[0] > 64 &&
969+
dotShape[1] > 64 && (elemWidth == 16 || elemWidth == 8)) {
973970
if (transformTwoClusterWithLocalLoadAndAll(builder, loc).failed()) {
974-
LDBG("Encountered failure when trying to execute the NS3 ping pong "
975-
"cluster transformation");
971+
LDBG("Encountered failure when trying to execute the "
972+
"TwoClusterWithLocalLoadAndAll transformation");
976973
return;
977974
}
978975
addAsymmetricSyncToLoop(builder, loc);

0 commit comments

Comments
 (0)