@@ -1101,6 +1101,146 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };

+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
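+///
+/// A sketch of the rewrite on hypothetical shapes (padding regions elided for
+/// brevity). Only dimensions whose reassociation group is not split may carry
+/// non-zero padding, so
+///
+///   %padded = tensor.pad %src low[0, 1] high[0, 1] { ... }
+///       : tensor<8x30xf32> to tensor<8x32xf32>
+///   %expanded = tensor.expand_shape %padded [[0, 1], [2]]
+///       output_shape [2, 4, 32] : tensor<8x32xf32> into tensor<2x4x32xf32>
+///
+/// becomes
+///
+///   %expanded = tensor.expand_shape %src [[0, 1], [2]]
+///       output_shape [2, 4, 30] : tensor<8x30xf32> into tensor<2x4x30xf32>
+///   %padded = tensor.pad %expanded low[0, 0, 1] high[0, 0, 1] { ... }
+///       : tensor<2x4x30xf32> to tensor<2x4x32xf32>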
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    // Only pads with a constant padding value are supported; the new pad is
+    // rebuilt from that value below.
+    if (!padOp.getConstantPaddingValue())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
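+    // A pad amount is considered zero only if it is a static 0 attribute or
+    // an arith.constant 0.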
+    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+      if (auto attr = dyn_cast<Attribute>(padValue)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+          return intAttr.getInt() == 0;
+      }
+
+      if (auto val = dyn_cast<Value>(padValue)) {
+        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+            return attr.getInt() == 0;
+        }
+      }
+
+      // A dynamic padding amount that does not fold to a constant cannot be
+      // proven to be zero, so conservatively treat it as non-zero.
+      return false;
+    };
+
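+    // The expand_shape can only be bubbled through the pad if every padded
+    // dimension stays intact, i.e. is not split by the reassociation.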
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+        return failure();
+    }
+
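+    // Distribute the pad amounts over the expanded dimensions: a split group
+    // is known to be unpadded, so replicating its zero padding is safe.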
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(low[idx]);
+        newHigh.push_back(high[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
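+    // Seed the shape of the new expand_shape from the expanded result type;
+    // entries for padded groups are overwritten with pre-padding sizes below.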
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
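+    // A padded group expands to a single dimension (checked above); its size
+    // before padding is the size of the corresponding source dimension.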
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+
+        for (auto outDimIdx : outGroup) {
+          expandedShape[outDimIdx] = originalSizeOFR;
+        }
+      }
+    }
+
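+    // Remaining unset entries belong to zero-padded groups; their sizes can
+    // be taken verbatim from the output_shape of the original expand_shape.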
+    SmallVector<OpFoldResult> outputShape = expandOp.getMixedOutputShape();
+    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+      if (dimSize == ShapedType::kDynamic && !expandedShape[outDimIdx])
+        expandedShape[outDimIdx] = outputShape[outDimIdx];
+    }
+
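+    // Derive the static portion of the expanded type from the mixed sizes.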
+    SmallVector<int64_t> staticExpandedShape;
+    for (OpFoldResult dim : expandedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticExpandedShape.push_back(intAttr.getInt());
+        } else {
+          staticExpandedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticExpandedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
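+    // Create the new expand_shape on the unpadded source, then pad the
+    // expanded tensor with the redistributed amounts.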
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations, expandedShape);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2389,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }