@@ -763,6 +763,7 @@ Pingponger::transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
763
763
asyncWaitOp->erase ();
764
764
}
765
765
}
766
+ assert (newAsyncWaitOp != nullptr );
766
767
767
768
moveOpAndPredecessorsUpSameBlock (lLoadOps[0 ]);
768
769
moveOpAndPredecessorsUpSameBlock (lLoadOps[1 ]);
@@ -917,21 +918,18 @@ void Pingponger::getDotPingponged() {
917
918
auto aType = scaledDotOps[0 ].getA ().getType ();
918
919
auto aShape = aType.getShape ();
919
920
auto elemWidth = aType.getElementTypeBitWidth ();
920
- int64_t tileSize = scaledDotShape[0 ] * scaledDotShape[1 ] * aShape[1 ];
921
921
922
- // 256x256x256 (128xi8)
923
- if (tileSize == 8388608 && aShape [0 ] == 256 && aShape [1 ] == 128 &&
922
+ // MxN = 256x256
923
+ if (scaledDotShape [0 ] == 256 && scaledDotShape [1 ] == 256 &&
924
924
elemWidth == 8 ) {
925
- kWidth = 16 ;
926
925
if (transformTwoClusterWithAsyncAndAll (builder, scaledDotOps[0 ]->getLoc ())
927
926
.failed ()) {
928
- LDBG (
929
- " Encountered failure when trying to execute the two-step ping pong "
930
- " cluster transformation" );
927
+ LDBG (" Encountered failure when trying to execute the"
928
+ " TwoClusterWithAsyncAndAll transformation" );
931
929
return ;
932
930
}
931
+ addAsymmetricSyncToLoop (builder, loc);
933
932
}
934
- addAsymmetricSyncToLoop (builder, loc);
935
933
return ;
936
934
} else if (scaledDotOps.size () == 1 )
937
935
return ;
@@ -941,7 +939,6 @@ void Pingponger::getDotPingponged() {
941
939
// Determine if we have a persistent GEMM. This will decide how we interpret
942
940
// any memory operations that we find in conditionals.
943
941
auto assumeNotTaken = isPersistentGemm (dotOps.size ());
944
-
945
942
// Compute tile size, kWidth, and mfma type.
946
943
auto dotType = dotOps[0 ].getType ();
947
944
auto dotShape = dotType.getShape ();
@@ -968,11 +965,11 @@ void Pingponger::getDotPingponged() {
968
965
LDBG (" Currently only support num_warp=8 for async PP" );
969
966
return ;
970
967
}
971
- if (numStages > 2 && dotOps.size () == 1 && tileSize == mediumTile &&
972
- aShape [1 ] == 32 && elemWidth == 16 ) {
968
+ if (numStages > 2 && dotOps.size () == 1 && dotShape[ 0 ] > 64 &&
969
+ dotShape [1 ] > 64 && ( elemWidth == 16 || elemWidth == 8 ) ) {
973
970
if (transformTwoClusterWithLocalLoadAndAll (builder, loc).failed ()) {
974
- LDBG (" Encountered failure when trying to execute the NS3 ping pong "
975
- " cluster transformation" );
971
+ LDBG (" Encountered failure when trying to execute the "
972
+ " TwoClusterWithLocalLoadAndAll transformation" );
976
973
return ;
977
974
}
978
975
addAsymmetricSyncToLoop (builder, loc);
0 commit comments