Skip to content

Commit 03ff5d2

Browse files
jungpark-mlirmakslevental
authored and committed
[AMD] Enable block pingpong for smaller tiles (triton-lang#5820)
Recent experiments found that it also helps a few more configs, especially those with smaller tiles. Enable one-cluster pingpong for the 4-times-smaller tiles.
1 parent 1e6bae1 commit 03ff5d2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ void Pingponger::getDotPingponged() {
409409
// software pipelining and dot rank=2. Also only accept the for-loop with
410410
// supported combination of operations because this transformation is very
411411
// tightly scheduling the latencies.
412-
if (gLoadOps.size() != 2 || lLoadOps.size() != 2 || dotOps.size() != 1)
412+
if (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotOps.size() != 1)
413413
return;
414414

415415
// Pingpong scheduling tries to form two different types of the instruction
@@ -447,6 +447,7 @@ void Pingponger::getDotPingponged() {
447447
auto elemWidth = aType.getElementTypeBitWidth();
448448
int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth;
449449

450+
const int64_t minTile = 262144; // e.g. 32x128x64x16bit
450451
const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit
451452
const int64_t mediumTile = 33554432; // smallTile x 2
452453
const int64_t largeTile = 67108864; // e.g. 256x256x64x16bit
@@ -465,11 +466,13 @@ void Pingponger::getDotPingponged() {
465466
// times for issuing the memory operations and issuing dot operations,
466467
// smaller tile sizes are not likely to get any advantage from current dot
467468
// centric pingpong scheduling.
468-
if (tileSize == smallTile)
469+
if (tileSize <= smallTile && tileSize >= minTile)
469470
transformOnePPClusters(builder, loc);
470471
// numWarps=4 doesn't need asymmetric sync, return.
471472
return;
472473
} else if (numWarps == 8) { // Pingpong between warps from the same block
474+
if (gLoadOps.size() != 2 || lLoadOps.size() != 2)
475+
return;
473476
// Transform a loop where the tile size requires dots to be sliced
474477
if (tileSize == mediumTile) {
475478
if (transformTwoPPClusters(builder, dotOps[0]->getLoc()).failed())

0 commit comments

Comments
 (0)