@@ -1038,6 +1038,62 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };

+/// Carries per-dimension padding information for a padded tensor.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation
+/// based on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded size
+/// for each dimension, or failure if the expansion is not possible.
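+///
+/// As an illustrative sketch (not taken from a test): with reassociation
+/// [[0], [1, 2]] expanding tensor<?x32xf32> into tensor<?x4x8xf32>, padding
+/// low = [2, 0] / high = [3, 0] sits on the unexpanded dim 0 and becomes
+/// low = [2, 0, 0] / high = [3, 0, 0]; any nonzero padding on dim 1 would
+/// make the computation fail.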
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to expand the dimensions, since it will change
+  // the index values on which the padding value depends. This is not currently
+  // supported by the pad expansion patterns, but it could be implemented
+  // similarly to the expansion of linalg.generic ops with linalg.index ops in
+  // the body, as is done in `updateExpandedGenericOpRegion`.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases
+  // where it is possible (like expanding unit dims), but supporting these
+  // cases is NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape.begin(), expandedShape.end());
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
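+/// Pattern to fold a tensor.pad op with its producer tensor.collapse_shape op
+/// by bubbling the pad through the reshape, i.e.
+///
+///   collapse_shape -> pad  becomes  pad -> collapse_shape
+///
+/// As a hand-written sketch (not from a test):
+///
+///   %c = tensor.collapse_shape %src [[0, 1], [2]]
+///       : tensor<2x3x?xf32> into tensor<6x?xf32>
+///   %p = tensor.pad %c low[0, 1] high[0, 2] { ... }
+///
+/// becomes a pad of %src on its trailing dimension followed by the same
+/// collapse_shape.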
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1053,46 +1109,96 @@ class FoldPadWithProducerReshapeOpByExpansion
         padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();

     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
                                          "fusion blocked by control function");
     }

-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();

-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
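+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape through the pad, i.e.
+///
+///   pad -> expand_shape  becomes  expand_shape -> pad
+///
+/// Only dimensions left unexpanded by the reshape may be padded, and the pad
+/// value must be a constant (see computeExpandedPadding).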
+class FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }

-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
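+    // Create the source-size tensor.dim ops (and, below, the new
+    // expand_shape) right after the definition of the pad source, where the
+    // source value is guaranteed to be available.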
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // We know that any reassociation with multiple dims is not padded
+      // because of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());

-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());

     return success();
   }
@@ -1921,6 +2027,62 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };

+/// Computes the collapsed padding information for the given pad operation
+/// based on the provided reassociation indices. Returns a PadDimInfo
+/// containing the low and high padding amounts and the collapsed shape for
+/// each dimension, or failure if the collapse is not possible.
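+///
+/// As an illustrative sketch (not taken from a test): with reassociation
+/// [[0, 1], [2]] collapsing a pad of tensor<4x8x?xf32>, only dim 2 may be
+/// padded; low = [0, 0, 2] / high = [0, 0, 3] becomes low = [0, 2] /
+/// high = [0, 3], and the collapsed padded shape is {32, dynamic}.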
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to collapse the dimensions, since it will change
+  // the index values on which the padding value depends. This is not currently
+  // supported by the pad collapsing patterns, but it could be implemented
+  // similarly to the collapsing of linalg.generic ops with linalg.index ops in
+  // the body, as is done in `generateCollapsedIndexingRegion`.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some
+  // special cases where it is possible (like collapsing unit dims), but
+  // supporting these cases is NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize the padding amounts of the collapsed tensor to zero.
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+      padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
+    }
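+    // Fold the padded sizes of the grouped dims into a single collapsed size;
+    // SaturatedInteger keeps the result dynamic if any grouped size is
+    // dynamic.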
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
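+/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
+/// by bubbling the pad through the reshape, i.e.
+///
+///   expand_shape -> pad  becomes  pad -> expand_shape
+///
+/// The padding must not touch any expanded dimension group with more than one
+/// dim, and the pad value must be a constant (see computeCollapsedPadding).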
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1936,57 +2098,40 @@ class FoldPadWithProducerReshapeOpByCollapsing
         padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();

     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
                                          "fusion blocked by control function");
     }

-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();

-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }

     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());

     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2145,52 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };

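+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape through the pad, i.e.
+///
+///   pad -> collapse_shape  becomes  collapse_shape -> pad
+///
+/// Padding must not occur within a collapsed dimension group, and the pad
+/// value must be a constant (see computeCollapsedPadding).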
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }