@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries padding information for each dimension of a padded tensor.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation
+/// based on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
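+///
+/// For example (illustrative values), with reassociations [[0], [1, 2]],
+/// expandedShape [10, 4, 4], and a pad of low[1, 0] high[2, 0] on the
+/// collapsed tensor<10x16xf32>, the expanded padding is low[1, 0, 0]
+/// high[2, 0, 0] with paddedShape [13, 4, 4].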
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases
+  // where it is possible (like expanding unit dims), but supporting these
+  // cases is NYI, so disallow it for now.
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
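+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad. Illustrative sketch (names
+/// and sizes are hypothetical):
+///
+///   %pad = tensor.pad %src low[1, 0] high[2, 0] { ... }
+///       : tensor<10x16xf32> to tensor<13x16xf32>
+///   %res = tensor.expand_shape %pad [[0], [1, 2]] output_shape [13, 4, 4]
+///       : tensor<13x16xf32> into tensor<13x4x4xf32>
+///
+/// is rewritten to
+///
+///   %expand = tensor.expand_shape %src [[0], [1, 2]] output_shape [10, 4, 4]
+///       : tensor<10x16xf32> into tensor<10x4x4xf32>
+///   %res = tensor.pad %expand low[1, 0, 0] high[2, 0, 0] { ... }
+///       : tensor<10x4x4xf32> to tensor<13x4x4xf32>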
+class FoldExpandShapeWithProducerPadOp
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+                                   ControlFusionFn foldReshapes,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
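+    // Build the sizes of the pad source just after the source value so that
+    // the tensor.dim ops (and the new expand_shape op created below) are
+    // inserted before the original pad op.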
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // Any reassociation group with multiple dims is known to be unpadded
+      // because of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
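+    // Move the insertion point back to the expand_shape op so the replacement
+    // pad op is created at the position of the op it replaces.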
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
 
     return success();
   }
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation
+/// based on the provided reassociation indices. Returns a PadDimInfo
+/// containing the low and high padding amounts and the collapsed shape for
+/// each dimension, or failure if the collapse is not possible.
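+///
+/// For example (illustrative values), with reassociations [[0], [1, 2]] and a
+/// pad of low[1, 0, 0] high[2, 0, 0] producing tensor<13x4x4xf32>, the
+/// collapsed padding is low[1, 0] high[2, 0] with paddedShape [13, 16].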
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some
+  // special cases where it is possible (like collapsing unit dims), but
+  // supporting these cases is NYI, so disallow it for now.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros.
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+    }
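+    // The collapsed size is the product of the group's padded sizes.
+    // SaturatedInteger propagates dynamic sizes: if any dim in the group is
+    // dynamic, the collapsed size is dynamic as well.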
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
 
-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
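+    // For each uncollapsed padded dim, the new output size is
+    // low + high + the unpadded size, folded with addMap above.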
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }
 
     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
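+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad. Illustrative sketch (names
+/// and sizes are hypothetical):
+///
+///   %pad = tensor.pad %src low[1, 0, 0] high[2, 0, 0] { ... }
+///       : tensor<10x4x4xf32> to tensor<13x4x4xf32>
+///   %res = tensor.collapse_shape %pad [[0], [1, 2]]
+///       : tensor<13x4x4xf32> into tensor<13x16xf32>
+///
+/// is rewritten to
+///
+///   %collapse = tensor.collapse_shape %src [[0], [1, 2]]
+///       : tensor<10x4x4xf32> into tensor<10x16xf32>
+///   %res = tensor.pad %collapse low[1, 0] high[2, 0] { ... }
+///       : tensor<10x16xf32> to tensor<13x16xf32>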
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+                                                 controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }