 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <utility>
 
@@ -1100,6 +1102,20 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
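+/// Returns true if `value` is a statically known integer zero, either an
+/// IntegerAttr or a Value produced by an arith.constant. A dynamic value that
+/// may still be zero at runtime conservatively yields false.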
+static bool isZero(OpFoldResult value) {
+  if (auto attr = dyn_cast<Attribute>(value)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return intAttr.getInt() == 0;
+  }
+  if (auto val = dyn_cast<Value>(value)) {
+    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+        return attr.getInt() == 0;
+    }
+  }
+  return false;
+}
+
 /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
 /// by bubbling the expand_shape before the pad.
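+/// For example, assuming illustrative shapes and pad amounts (this sketch is
+/// not taken from the test suite):
+///
+///   %p = tensor.pad %src low[0, 1] high[0, 1] ...
+///       : tensor<?x10xf32> to tensor<?x12xf32>
+///   %r = tensor.expand_shape %p [[0, 1], [2]] output_shape [%d, 8, 12]
+///       : tensor<?x12xf32> into tensor<?x8x12xf32>
+///
+/// becomes
+///
+///   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [%d, 8, 10]
+///       : tensor<?x10xf32> into tensor<?x8x10xf32>
+///   %r = tensor.pad %e low[0, 0, 1] high[0, 0, 1] ...
+///       : tensor<?x8x10xf32> to tensor<?x8x12xf32>
+///
+/// Only dims left un-expanded (single-entry reassociation groups) may carry
+/// non-zero padding, so the pad amounts transfer to the expanded rank
+/// unchanged.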
 struct FoldReshapeWithProducerPadOpByExpansion
@@ -1125,41 +1141,29 @@ struct FoldReshapeWithProducerPadOpByExpansion
                                          "fusion blocked by control function");
     }
 
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
     SmallVector<ReassociationIndices> reassociations =
         expandOp.getReassociationIndices();
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
 
-    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
-      if (auto attr = dyn_cast<Attribute>(padValue)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-          return intAttr.getInt() == 0;
-      }
-
-      if (auto val = dyn_cast<Value>(padValue)) {
-        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
-          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
-            return attr.getInt() == 0;
-        }
-      }
-
-      // When padding is dynamic and not constant, we don't know if it's zero
-      // or not, so we return false here.
-      return false;
-    };
-
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[idx];
       OpFoldResult h = high[idx];
-      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
         return failure();
     }
 
     SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newLow.push_back(low[idx]);
+        newHigh.push_back(high[idx]);
       }
     }
 
@@ -1176,11 +1180,11 @@ struct FoldReshapeWithProducerPadOpByExpansion
       }
     }
 
-    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];
 
-      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+      if (!isZero(l) || !isZero(h)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(inDimIdx);
 
@@ -1193,7 +1197,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
           originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
         }
 
-      for (auto outDimIdx : outGroup) {
+      for (auto outDimIdx : reInd) {
           expandedShape[outDimIdx] = originalSizeOFR;
         }
       }
@@ -1240,6 +1244,125 @@ struct FoldReshapeWithProducerPadOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad.
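+/// For example, assuming illustrative shapes and pad amounts (this sketch is
+/// not taken from the test suite):
+///
+///   %p = tensor.pad %src low[0, 0, 2] high[0, 0, 3] ...
+///       : tensor<?x4x10xf32> to tensor<?x4x15xf32>
+///   %r = tensor.collapse_shape %p [[0, 1], [2]]
+///       : tensor<?x4x15xf32> into tensor<?x15xf32>
+///
+/// becomes
+///
+///   %c = tensor.collapse_shape %src [[0, 1], [2]]
+///       : tensor<?x4x10xf32> into tensor<?x10xf32>
+///   %r = tensor.pad %c low[0, 2] high[0, 3] ...
+///       : tensor<?x10xf32> to tensor<?x15xf32>
+///
+/// Groups that collapse more than one dim must be unpadded in every member
+/// dim; only single-dim groups may carry padding.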
+struct FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          collapseOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        collapseOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    // Collapsing several dims into one is only supported when none of them
+    // carries padding; a padded dim must stay in a single-entry group so its
+    // low/high amounts transfer unchanged.
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1) {
+        for (auto dimIdx : reInd) {
+          if (!isZero(low[dimIdx]) || !isZero(high[dimIdx]))
+            return failure();
+        }
+      }
+    }
+
+    // Multi-dim groups are unpadded, so the first member's pad amounts (zero
+    // for such groups) represent the whole group.
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      newLow.push_back(low[reInd[0]]);
+      newHigh.push_back(high[reInd[0]]);
+    }
+
+    Location loc = collapseOp.getLoc();
+    RankedTensorType resultType = collapseOp.getResultType();
+
+    // Seed the collapsed shape from the original result type; entries for
+    // padded dims are overwritten below with their pre-pad sizes.
+    SmallVector<OpFoldResult> collapsedShape;
+    for (int64_t dimSize : resultType.getShape()) {
+      if (dimSize == ShapedType::kDynamic) {
+        collapsedShape.push_back(OpFoldResult{});
+      } else {
+        collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[reInd[0]];
+      OpFoldResult h = high[reInd[0]];
+
+      if (!isZero(l) || !isZero(h)) {
+        // A padded group is necessarily a single dim (checked above), so its
+        // pre-pad size can be read off the pad op's source.
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(reInd[0]);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+        collapsedShape[inDimIdx] = originalSizeOFR;
+      }
+    }
+
+    SmallVector<int64_t> staticCollapsedShape;
+    for (OpFoldResult dim : collapsedShape) {
+      // `dim` may be a null OpFoldResult (dynamic, unpadded dim); the
+      // `_if_present` cast variant keeps that case safe.
+      auto attr = llvm::dyn_cast_if_present<Attribute>(dim);
+      if (auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
+        staticCollapsedShape.push_back(intAttr.getInt());
+      } else {
+        staticCollapsedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newCollapseType = RankedTensorType::get(
+        staticCollapsedShape, padOp.getSource().getType().getElementType());
+    auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, newCollapseType, padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(collapseOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2388,6 +2511,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                  controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
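+
+// Usage sketch (hypothetical, not part of this patch): the populated patterns
+// are typically run under the greedy driver, with the control function gating
+// each fusion. `applyPatternsGreedily` is the current entry point; older MLIR
+// revisions spell it `applyPatternsAndFoldGreedily`.
+//
+//   RewritePatternSet patterns(op->getContext());
+//   linalg::populateFoldReshapeOpsByCollapsingPatterns(
+//       patterns, [](OpOperand *) { return true; });
+//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
+//     signalPassFailure();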