@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
11251125 return success (*map != initialMap);
11261126}
11271127
1128+ // / Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129+ // / e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130+ // / place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131+ // / into `replacementsMap`. If no entries were added to `replacementsMap`,
1132+ // / nothing was found.
1133+ static void shortenAddChainsContainingAll (
1134+ AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4 > &exprsToRemove,
1135+ AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136+ auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137+ if (!binOp)
1138+ return ;
1139+ AffineExpr lhs = binOp.getLHS ();
1140+ AffineExpr rhs = binOp.getRHS ();
1141+ if (binOp.getKind () != AffineExprKind::Add) {
1142+ shortenAddChainsContainingAll (lhs, exprsToRemove, newVal, replacementsMap);
1143+ shortenAddChainsContainingAll (rhs, exprsToRemove, newVal, replacementsMap);
1144+ return ;
1145+ }
1146+ SmallVector<AffineExpr> toPreserve;
1147+ llvm::SmallDenseSet<AffineExpr, 4 > ourTracker (exprsToRemove);
1148+ AffineExpr thisTerm = rhs;
1149+ AffineExpr nextTerm = lhs;
1150+
1151+ while (thisTerm) {
1152+ if (!ourTracker.erase (thisTerm)) {
1153+ toPreserve.push_back (thisTerm);
1154+ shortenAddChainsContainingAll (thisTerm, exprsToRemove, newVal,
1155+ replacementsMap);
1156+ }
1157+ auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158+ if (!nextBinOp || nextBinOp.getKind () != AffineExprKind::Add) {
1159+ thisTerm = nextTerm;
1160+ nextTerm = AffineExpr ();
1161+ } else {
1162+ thisTerm = nextBinOp.getRHS ();
1163+ nextTerm = nextBinOp.getLHS ();
1164+ }
1165+ }
1166+ if (!ourTracker.empty ())
1167+ return ;
1168+ // We reverse the terms to be preserved here in order to preserve
1169+ // associativity between them.
1170+ AffineExpr newExpr = newVal;
1171+ for (AffineExpr preserved : llvm::reverse (toPreserve))
1172+ newExpr = newExpr + preserved;
1173+ replacementsMap.insert ({e, newExpr});
1174+ }
1175+
1176+ // / If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177+ // / ...` (not necessarily in order) where the set of the `x_i` is the set of
1178+ // / outputs of an `affine.delinearize_index` whos inverse is that expression,
1179+ // / replace that expression with the input of that delinearize_index op.
1180+ // /
1181+ // / `unitDimInput` is the input that was detected as the potential start to this
1182+ // / replacement chain - if it isn't the rightmost result of the delinearization,
1183+ // / this method fails. (This is intended to ensure we don't have redundant scans
1184+ // / over the same expression).
1185+ // /
1186+ // / While this currently only handles delinearizations with a constant basis,
1187+ // / that isn't a fundamental limitation.
1188+ // /
1189+ // / This is a utility function for `replaceDimOrSym` below.
1190+ static LogicalResult replaceAffineDelinearizeIndexInverseExpression (
1191+ AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
1192+ SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
1193+ if (!delinOp.getDynamicBasis ().empty ())
1194+ return failure ();
1195+ if (resultToReplace != delinOp.getMultiIndex ().back ())
1196+ return failure ();
1197+
1198+ MLIRContext *ctx = delinOp.getContext ();
1199+ SmallVector<AffineExpr> resToExpr (delinOp.getNumResults (), AffineExpr ());
1200+ for (auto [pos, dim] : llvm::enumerate (dims)) {
1201+ auto asResult = dyn_cast_if_present<OpResult>(dim);
1202+ if (!asResult)
1203+ continue ;
1204+ if (asResult.getOwner () == delinOp.getOperation ())
1205+ resToExpr[asResult.getResultNumber ()] = getAffineDimExpr (pos, ctx);
1206+ }
1207+ for (auto [pos, sym] : llvm::enumerate (syms)) {
1208+ auto asResult = dyn_cast_if_present<OpResult>(sym);
1209+ if (!asResult)
1210+ continue ;
1211+ if (asResult.getOwner () == delinOp.getOperation ())
1212+ resToExpr[asResult.getResultNumber ()] = getAffineSymbolExpr (pos, ctx);
1213+ }
1214+ if (llvm::is_contained (resToExpr, AffineExpr ()))
1215+ return failure ();
1216+
1217+ bool isDimReplacement = llvm::all_of (resToExpr, llvm::IsaPred<AffineDimExpr>);
1218+ int64_t stride = 1 ;
1219+ llvm::SmallDenseSet<AffineExpr, 4 > expectedExprs;
1220+ // This isn't zip_equal since sometimes the delinearize basis is missing a
1221+ // size for the first result.
1222+ for (auto [binding, size] : llvm::zip (
1223+ llvm::reverse (resToExpr), llvm::reverse (delinOp.getStaticBasis ()))) {
1224+ expectedExprs.insert (binding * getAffineConstantExpr (stride, ctx));
1225+ stride *= size;
1226+ }
1227+ if (resToExpr.size () != delinOp.getStaticBasis ().size ())
1228+ expectedExprs.insert (resToExpr[0 ] * stride);
1229+
1230+ DenseMap<AffineExpr, AffineExpr> replacements;
1231+ AffineExpr delinInExpr = isDimReplacement
1232+ ? getAffineDimExpr (dims.size (), ctx)
1233+ : getAffineSymbolExpr (syms.size (), ctx);
1234+
1235+ for (AffineExpr e : map->getResults ())
1236+ shortenAddChainsContainingAll (e, expectedExprs, delinInExpr, replacements);
1237+ if (replacements.empty ())
1238+ return failure ();
1239+
1240+ AffineMap origMap = *map;
1241+ if (isDimReplacement)
1242+ dims.push_back (delinOp.getLinearIndex ());
1243+ else
1244+ syms.push_back (delinOp.getLinearIndex ());
1245+ *map = origMap.replace (replacements, dims.size (), syms.size ());
1246+
1247+ // Blank out dead dimensions and symbols
1248+ for (AffineExpr e : resToExpr) {
1249+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250+ unsigned pos = d.getPosition ();
1251+ if (!map->isFunctionOfDim (pos))
1252+ dims[pos] = nullptr ;
1253+ }
1254+ if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255+ unsigned pos = s.getPosition ();
1256+ if (!map->isFunctionOfSymbol (pos))
1257+ syms[pos] = nullptr ;
1258+ }
1259+ }
1260+ return success ();
1261+ }
1262+
11281263// / Replace all occurrences of AffineExpr at position `pos` in `map` by the
11291264// / defining AffineApplyOp expression and operands.
11301265// / When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
11571292 syms);
11581293 }
11591294
1295+ if (auto delinOp = v.getDefiningOp <affine::AffineDelinearizeIndexOp>()) {
1296+ return replaceAffineDelinearizeIndexInverseExpression (delinOp, v, map, dims,
1297+ syms);
1298+ }
1299+
11601300 auto affineApply = v.getDefiningOp <AffineApplyOp>();
11611301 if (!affineApply)
11621302 return failure ();
0 commit comments