Skip to content

Commit b204a07

Browse files
jungpark-mlirAlexAUT
authored andcommitted
[AMD] Enable Pingpong by default on gfx950 arch (triton-lang#7697)
List of enabling conditions - FP/BF16 GEMM with M,N>64 tilesize when num_stages=3 and num_warps=8 - GEMM using `dot_scaled` with M=N=256 tile size when num_stages=2 and num_warps=8 - FA with num_stages=4 Only with using async_copy.
1 parent 087e624 commit b204a07

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

third_party/amd/backend/compiler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def get_min_dot_size(target: GPUTarget):
1818
return lambda lhs_type, rhs_type: (1, 1, 1)
1919

2020

21-
def is_pingpong_schedule_enabled(arch):
22-
return (arch == "gfx942") if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
21+
def is_pingpong_schedule_enabled(arch, use_async_copy):
22+
return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
23+
) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
2324

2425

2526
def is_in_thread_transpose_enabled(arch):
@@ -233,7 +234,7 @@ def make_ttgir(mod, metadata, options):
233234
global_prefetch = knobs.amd.global_prefetch
234235
local_prefetch = knobs.amd.local_prefetch
235236
use_async_copy = knobs.amd.use_async_copy
236-
use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
237+
use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)
237238

238239
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
239240
use_block_pingpong)

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,7 @@ Pingponger::transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
764764
asyncWaitOp->erase();
765765
}
766766
}
767+
assert(newAsyncWaitOp != nullptr);
767768

768769
moveOpAndPredecessorsUpSameBlock(lLoadOps[0]);
769770
moveOpAndPredecessorsUpSameBlock(lLoadOps[1]);
@@ -918,21 +919,18 @@ void Pingponger::getDotPingponged() {
918919
auto aType = scaledDotOps[0].getA().getType();
919920
auto aShape = aType.getShape();
920921
auto elemWidth = aType.getElementTypeBitWidth();
921-
int64_t tileSize = scaledDotShape[0] * scaledDotShape[1] * aShape[1];
922922

923-
// 256x256x256 (128xi8)
924-
if (tileSize == 8388608 && aShape[0] == 256 && aShape[1] == 128 &&
923+
// MxN = 256x256
924+
if (scaledDotShape[0] == 256 && scaledDotShape[1] == 256 &&
925925
elemWidth == 8) {
926-
kWidth = 16;
927926
if (transformTwoClusterWithAsyncAndAll(builder, scaledDotOps[0]->getLoc())
928927
.failed()) {
929-
LDBG(
930-
"Encountered failure when trying to execute the two-step ping pong "
931-
"cluster transformation");
928+
LDBG("Encountered failure when trying to execute the"
929+
"TwoClusterWithAsyncAndAll transformation");
932930
return;
933931
}
932+
addAsymmetricSyncToLoop(builder, loc);
934933
}
935-
addAsymmetricSyncToLoop(builder, loc);
936934
return;
937935
} else if (scaledDotOps.size() == 1)
938936
return;
@@ -942,7 +940,6 @@ void Pingponger::getDotPingponged() {
942940
// Determine if we have a persistent GEMM. This will decide how we interpret
943941
// any memory operations that we find in conditionals.
944942
auto assumeNotTaken = isPersistentGemm(dotOps.size());
945-
946943
// Compute tile size, kWidth, and mfma type.
947944
auto dotType = dotOps[0].getType();
948945
auto dotShape = dotType.getShape();
@@ -969,11 +966,11 @@ void Pingponger::getDotPingponged() {
969966
LDBG("Currently only support num_warp=8 for async PP");
970967
return;
971968
}
972-
if (numStages > 2 && dotOps.size() == 1 && tileSize == mediumTile &&
973-
aShape[1] == 32 && elemWidth == 16) {
969+
if (numStages > 2 && dotOps.size() == 1 && dotShape[0] > 64 &&
970+
dotShape[1] > 64 && (elemWidth == 16 || elemWidth == 8)) {
974971
if (transformTwoClusterWithLocalLoadAndAll(builder, loc).failed()) {
975-
LDBG("Encountered failure when trying to execute the NS3 ping pong "
976-
"cluster transformation");
972+
LDBG("Encountered failure when trying to execute the "
973+
"TwoClusterWithLocalLoadAndAll transformation");
977974
return;
978975
}
979976
addAsymmetricSyncToLoop(builder, loc);

0 commit comments

Comments
 (0)