@@ -167,6 +167,7 @@ class LayoutRematerialization {
 
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
+  void reduceLoopCarriedValues();
   // Existing tuples of (value, layout) that need to be updated when recreating
   // scf ops. This prevents keeping track of Values that have been deleted when
   // rewriting slices.
@@ -1009,6 +1010,93 @@ void LayoutRematerialization::updateRematMapping(
   }
 }
 
+/// Reduce loop-carried values: when a value yielded by a loop is used after
+/// the loop and another yielded value plus a convert layout operation can
+/// replace it, rewrite its users accordingly.
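+///
+/// Illustrative sketch (names and layouts are hypothetical, and the IR is
+/// schematic rather than exact syntax):
+///
+///   %res:2 = scf.for ... iter_args(%p = %ptr, %pRemat = %ptrRemat)
+///       -> (!tt.ptr<tensor<64x64xf16, #layoutA>>,
+///           !tt.ptr<tensor<64x64xf16, #layoutB>>) { ... }
+///   %v = tt.load %res#0 : !tt.ptr<tensor<64x64xf16, #layoutA>>
+///
+/// becomes
+///
+///   %t = tt.load %res#1 : !tt.ptr<tensor<64x64xf16, #layoutB>>
+///   %v = convert_layout %t : tensor<64x64xf16, #layoutB>
+///            -> tensor<64x64xf16, #layoutA>
+///
+/// after which %res#0 is unused and the corresponding loop-carried value can
+/// be removed.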
+void LayoutRematerialization::reduceLoopCarriedValues() {
+  for (auto [pair, val] : rematMapping) {
+    auto arg = dyn_cast<BlockArgument>(pair.first);
+    if (!arg)
+      continue;
+
+    if (!isTensorPointerType(arg.getType()))
+      continue;
+
+    auto loopOp = dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
+    if (!loopOp)
+      continue;
+
+    // Loop arguments that correspond to an unused loop result are not
+    // interesting.
+    OpResult loopRes = loopOp.getTiedLoopResult(arg);
+    if (loopRes.getNumUses() == 0)
+      continue;
+
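+    // Rewrite a user of the original loop result so it consumes the
+    // rematerialized result instead, inserting layout conversions as needed;
+    // AdvanceOp users are handled recursively so their users are rewritten
+    // too.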
+    std::function<void(Operation *, Value)> processUser = [&](Operation *user,
+                                                              Value rematRes) {
+      Location loc = user->getLoc();
+      OpBuilder rewriter(user);
+
+      TypeSwitch<Operation *>(user)
+          .Case<LoadOp>([&](auto loadOp) {
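+            // Load through the rematerialized pointer, then convert the
+            // loaded value back to the layout the original load produced.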
+            auto newLoadOp =
+                rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
+            auto convOp = rewriter.create<ConvertLayoutOp>(
+                loc, loadOp.getType(), newLoadOp.getResult());
+            loadOp->replaceAllUsesWith(convOp);
+            opToDelete.insert(loadOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *loadOp << "\n"
+                     << "with:\n\t" << *newLoadOp << "\n"
+                     << "\t" << *convOp << "\n";
+            });
+          })
+          .Case<StoreOp>([&](auto storeOp) {
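+            // The rematerialized pointer has a different pointee layout, so
+            // convert the stored data to that layout before storing.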
+            Value data = storeOp.getOperand(1);
+            auto dataType = cast<RankedTensorType>(data.getType());
+            auto newPtrType = cast<PointerType>(rematRes.getType());
+            Attribute encoding =
+                cast<RankedTensorType>(newPtrType.getPointeeType())
+                    .getEncoding();
+            RankedTensorType newDataType = dataType.cloneWithEncoding(encoding);
+            auto convOp =
+                rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
+            auto newStoreOp = rewriter.create<StoreOp>(
+                loc, rematRes, convOp, storeOp.getBoundaryCheck(),
+                storeOp.getCache(), storeOp.getEvict());
+            opToDelete.insert(storeOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *storeOp << "\n"
+                     << "with:\n\t" << *convOp << "\n"
+                     << "\t" << *newStoreOp << "\n";
+            });
+          })
+          .Case<AdvanceOp>([&](auto advanceOp) {
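+            // Recreate the advance on the rematerialized pointer (which keeps
+            // its layout), then recurse into the original advance's users.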
+            auto newAdvanceOp = rewriter.create<AdvanceOp>(
+                loc, rematRes.getType(), rematRes, advanceOp.getOffsets());
+            opToDelete.insert(advanceOp);
+            LLVM_DEBUG({
+              DBGS() << "Replaced:\n\t" << *advanceOp << "\n"
+                     << "with:\n\t" << *newAdvanceOp << "\n";
+            });
+
+            for (Operation *user : advanceOp->getUsers())
+              processUser(user, newAdvanceOp.getResult());
+          })
+          .Default([](auto op) {
+            llvm::report_fatal_error(llvm::Twine(
+                "Unsupported operation in backward rematerialization: '" +
+                op->getName().getStringRef() + "'"));
+          });
+    };
+
+    // Replace the loop result corresponding to the argument with the
+    // equivalent rematerialized loop result.
+    OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
+    for (Operation *user : loopRes.getUsers())
+      processUser(user, rematRes);
+  }
+}
+
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                                            DenseMap<Value, Attribute> &layout,
                                            ConvertLayoutOp convertOp,
@@ -1269,76 +1357,7 @@ void LayoutRematerialization::backwardRematerialization() {
     }
   }
 
-  // Reduce loop carried values if the value can be removed by using another
-  // loop yielded value plus a convert layout operation.
-  for (auto [pair, val] : rematMapping) {
-    auto arg = dyn_cast<BlockArgument>(pair.first);
-    if (!arg)
-      continue;
-
-    if (!isTensorPointerType(arg.getType()))
-      continue;
-
-    if (auto loopOp =
-            dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp())) {
-      // Loop arguments that corresponds to a loop result which is not used are
-      // not interesting.
-      OpResult loopRes = loopOp.getTiedLoopResult(arg);
-      if (loopRes.getNumUses() == 0)
-        continue;
-
-      // Replace the loop result corresponding to the argument with an
-      // equivalent loop result.
-      auto rematArg = cast<BlockArgument>(val);
-      OpResult rematRes = loopOp.getTiedLoopResult(rematArg);
-
-      for (Operation *user : loopRes.getUsers()) {
-        Location loc = user->getLoc();
-        OpBuilder rewriter(user);
-
-        TypeSwitch<Operation *>(user)
-            .Case<LoadOp>([&](auto loadOp) {
-              auto newLoadOp =
-                  rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
-              auto convOp = rewriter.create<ConvertLayoutOp>(
-                  loc, loadOp.getType(), newLoadOp.getResult());
-              loadOp->replaceAllUsesWith(convOp);
-              opToDelete.insert(loadOp);
-              LLVM_DEBUG({
-                DBGS() << "Replaced:\n\t" << *loadOp << "\n";
-                DBGS() << "with:\n\t" << *newLoadOp << "\n"
-                       << "\t" << *convOp << "\n";
-              });
-            })
-            .Case<StoreOp>([&](auto storeOp) {
-              Value data = storeOp.getOperand(1);
-              auto dataType = cast<RankedTensorType>(data.getType());
-              auto newPtrType = cast<PointerType>(rematRes.getType());
-              Attribute encoding =
-                  cast<RankedTensorType>(newPtrType.getPointeeType())
-                      .getEncoding();
-              RankedTensorType newDataType =
-                  dataType.cloneWithEncoding(encoding);
-              auto convOp =
-                  rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
-              auto newStoreOp = rewriter.create<StoreOp>(
-                  loc, rematRes, convOp, storeOp.getBoundaryCheck(),
-                  storeOp.getCache(), storeOp.getEvict());
-              opToDelete.insert(storeOp);
-              LLVM_DEBUG({
-                DBGS() << "Replaced:\n\t" << *storeOp << "\n";
-                DBGS() << "with:\n\t" << *convOp << "\n"
-                       << "\t" << *newStoreOp << "\n";
-              });
-            })
-            .Default([](auto op) {
-              llvm::report_fatal_error(llvm::Twine(
-                  "Unsupported operation in backward rematerialization: '" +
-                  op->getName().getStringRef() + "'"));
-            });
-      }
-    }
-  }
+  reduceLoopCarriedValues();
 }
 
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {