@@ -764,6 +764,7 @@ Pingponger::transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
764764 asyncWaitOp->erase ();
765765 }
766766 }
767+ assert (newAsyncWaitOp != nullptr );
767768
768769 moveOpAndPredecessorsUpSameBlock (lLoadOps[0 ]);
769770 moveOpAndPredecessorsUpSameBlock (lLoadOps[1 ]);
@@ -918,21 +919,18 @@ void Pingponger::getDotPingponged() {
918919 auto aType = scaledDotOps[0 ].getA ().getType ();
919920 auto aShape = aType.getShape ();
920921 auto elemWidth = aType.getElementTypeBitWidth ();
921- int64_t tileSize = scaledDotShape[0 ] * scaledDotShape[1 ] * aShape[1 ];
922922
923- // 256x256x256 (128xi8)
924- if (tileSize == 8388608 && aShape [0 ] == 256 && aShape [1 ] == 128 &&
923+ // MxN = 256x256
924+ if (scaledDotShape [0 ] == 256 && scaledDotShape [1 ] == 256 &&
925925 elemWidth == 8 ) {
926- kWidth = 16 ;
927926 if (transformTwoClusterWithAsyncAndAll (builder, scaledDotOps[0 ]->getLoc ())
928927 .failed ()) {
929- LDBG (
930- " Encountered failure when trying to execute the two-step ping pong "
931- " cluster transformation" );
928+ LDBG (" Encountered failure when trying to execute the"
929+ " TwoClusterWithAsyncAndAll transformation" );
932930 return ;
933931 }
932+ addAsymmetricSyncToLoop (builder, loc);
934933 }
935- addAsymmetricSyncToLoop (builder, loc);
936934 return ;
937935 } else if (scaledDotOps.size () == 1 )
938936 return ;
@@ -942,7 +940,6 @@ void Pingponger::getDotPingponged() {
942940 // Determine if we have a persistent GEMM. This will decide how we interpret
943941 // any memory operations that we find in conditionals.
944942 auto assumeNotTaken = isPersistentGemm (dotOps.size ());
945-
946943 // Compute tile size, kWidth, and mfma type.
947944 auto dotType = dotOps[0 ].getType ();
948945 auto dotShape = dotType.getShape ();
@@ -969,11 +966,11 @@ void Pingponger::getDotPingponged() {
969966 LDBG (" Currently only support num_warp=8 for async PP" );
970967 return ;
971968 }
972- if (numStages > 2 && dotOps.size () == 1 && tileSize == mediumTile &&
973- aShape [1 ] == 32 && elemWidth == 16 ) {
969+ if (numStages > 2 && dotOps.size () == 1 && dotShape[ 0 ] > 64 &&
970+ dotShape [1 ] > 64 && ( elemWidth == 16 || elemWidth == 8 ) ) {
974971 if (transformTwoClusterWithLocalLoadAndAll (builder, loc).failed ()) {
975- LDBG (" Encountered failure when trying to execute the NS3 ping pong "
976- " cluster transformation" );
972+ LDBG (" Encountered failure when trying to execute the "
973+ " TwoClusterWithLocalLoadAndAll transformation" );
977974 return ;
978975 }
979976 addAsymmetricSyncToLoop (builder, loc);
0 commit comments