@@ -1050,19 +1050,24 @@ struct PadDimInfo {
10501050
10511051// / Computes the expanded padding information for the given pad operation based
10521052// / on the provided expanded shape and reassociation indices. Returns a list of
1053- // / PaddedDimInfo containing the low and high padding amounts and the padded
1053+ // / PadDimInfo containing the low and high padding amounts and the padded
10541054// / size for each dimension, or failure if the expansion is not possible.
10551055static FailureOr<PadDimInfo>
10561056computeExpandedPadding (tensor::PadOp padOp, ArrayRef<int64_t > expandedShape,
10571057 ArrayRef<ReassociationIndices> reassociations,
10581058 PatternRewriter &rewriter) {
1059- ArrayRef<int64_t > low = padOp.getStaticLow ();
1060- ArrayRef<int64_t > high = padOp.getStaticHigh ();
1059+ // If the padding value depends on the index values of the pad operation,
1060+ // then it may not be valid to expand the dimensions, since it will change
1061+ // the index values on which the padding value depends.
1062+ if (!padOp.getConstantPaddingValue ())
1063+ return failure ();
10611064
10621065 // Expanded dimensions cannot have padding because the resulting padding may
10631066 // not be representable by a tensor.pad op. There are some special cases where
10641067 // it is possible (like expanding unit dims), but supporting these cases is
10651068 // NYI, so disallow it for now.
1069+ ArrayRef<int64_t > low = padOp.getStaticLow ();
1070+ ArrayRef<int64_t > high = padOp.getStaticHigh ();
10661071 for (auto [reInd, l, h] : llvm::zip_equal (reassociations, low, high)) {
10671072 if (reInd.size () != 1 && (l != 0 || h != 0 ))
10681073 return failure ();
@@ -1101,8 +1106,6 @@ class FoldPadWithProducerReshapeOpByExpansion
11011106 padOp.getSource ().getDefiningOp <tensor::CollapseShapeOp>();
11021107 if (!reshapeOp)
11031108 return failure ();
1104- if (!reshapeOp->hasOneUse ())
1105- return failure ();
11061109
11071110 if (!controlFoldingReshapes (&padOp.getSourceMutable ())) {
11081111 return rewriter.notifyMatchFailure (padOp,
@@ -1116,7 +1119,7 @@ class FoldPadWithProducerReshapeOpByExpansion
11161119 padOp, expandedType.getShape (), reassociations, rewriter);
11171120 if (failed (maybeExpandedPadding))
11181121 return failure ();
1119- PadDimInfo expandedPadding = maybeExpandedPadding.value ();
1122+ PadDimInfo & expandedPadding = maybeExpandedPadding.value ();
11201123
11211124 Location loc = padOp->getLoc ();
11221125 RankedTensorType expandedPaddedType =
@@ -1137,12 +1140,12 @@ class FoldPadWithProducerReshapeOpByExpansion
11371140 ControlFusionFn controlFoldingReshapes;
11381141};
11391142
1140- class FoldExpandShapeWithProducerPadOp
1143+ class FoldReshapeWithProducerPadOpByExpansion
11411144 : public OpRewritePattern<tensor::ExpandShapeOp> {
11421145public:
1143- FoldExpandShapeWithProducerPadOp (MLIRContext *context,
1144- ControlFusionFn foldReshapes,
1145- PatternBenefit benefit = 1 )
1146+ FoldReshapeWithProducerPadOpByExpansion (MLIRContext *context,
1147+ ControlFusionFn foldReshapes,
1148+ PatternBenefit benefit = 1 )
11461149 : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
11471150 controlFoldingReshapes (std::move(foldReshapes)) {}
11481151
@@ -1151,8 +1154,6 @@ class FoldExpandShapeWithProducerPadOp
11511154 tensor::PadOp padOp = expandOp.getSrc ().getDefiningOp <tensor::PadOp>();
11521155 if (!padOp)
11531156 return failure ();
1154- if (!padOp->hasOneUse ())
1155- return failure ();
11561157
11571158 if (!controlFoldingReshapes (&expandOp.getSrcMutable ())) {
11581159 return rewriter.notifyMatchFailure (expandOp,
@@ -1166,7 +1167,7 @@ class FoldExpandShapeWithProducerPadOp
11661167 padOp, expandedType.getShape (), reassociations, rewriter);
11671168 if (failed (maybeExpandedPadding))
11681169 return failure ();
1169- PadDimInfo expandedPadding = maybeExpandedPadding.value ();
1170+ PadDimInfo & expandedPadding = maybeExpandedPadding.value ();
11701171
11711172 Location loc = expandOp->getLoc ();
11721173 SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape ();
@@ -2031,13 +2032,18 @@ static FailureOr<PadDimInfo>
20312032computeCollapsedPadding (tensor::PadOp padOp,
20322033 ArrayRef<ReassociationIndices> reassociations,
20332034 PatternRewriter &rewriter) {
2034- ArrayRef<int64_t > low = padOp.getStaticLow ();
2035- ArrayRef<int64_t > high = padOp.getStaticHigh ();
2035+ // If the padding value depends on the index values of the pad operation,
2036+ // then it may not be valid to collapse the dimensions, since it will change
2037+ // the index values on which the padding value depends.
2038+ if (!padOp.getConstantPaddingValue ())
2039+ return failure ();
20362040
20372041 // Collapsed dimensions cannot have padding because this can produce strided
20382042 // padding that isn't representable by a tensor.pad op. There are some special
2039- // cases where it it possible (like collapsing unit dims), but supporting
2043+ // cases where it is possible (like collapsing unit dims), but supporting
20402044 // these cases is NYI, so disallow it for now.
2045+ ArrayRef<int64_t > low = padOp.getStaticLow ();
2046+ ArrayRef<int64_t > high = padOp.getStaticHigh ();
20412047 for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
20422048 for (int64_t dim : reInd) {
20432049 if ((low[dim] != 0 || high[dim] != 0 ) && reInd.size () != 1 )
@@ -2053,10 +2059,12 @@ computeCollapsedPadding(tensor::PadOp padOp,
20532059
20542060 // Update padding for dimensions that are not being collapsed, and compute
20552061 // the collapsed padded shape.
2062+ SmallVector<OpFoldResult> mixedLowPad (padOp.getMixedLowPad ());
2063+ SmallVector<OpFoldResult> mixedHighPad (padOp.getMixedHighPad ());
20562064 for (auto [idx, reInd] : llvm::enumerate (reassociations)) {
20572065 if (reInd.size () == 1 ) {
2058- padDimInfo.lowPad [idx] = padOp. getMixedLowPad () [reInd[0 ]];
2059- padDimInfo.highPad [idx] = padOp. getMixedHighPad () [reInd[0 ]];
2066+ padDimInfo.lowPad [idx] = mixedLowPad [reInd[0 ]];
2067+ padDimInfo.highPad [idx] = mixedHighPad [reInd[0 ]];
20602068 }
20612069 SaturatedInteger collapsedSize = SaturatedInteger::wrap (1 );
20622070 for (int64_t dim : reInd) {
@@ -2084,8 +2092,6 @@ class FoldPadWithProducerReshapeOpByCollapsing
20842092 padOp.getSource ().getDefiningOp <tensor::ExpandShapeOp>();
20852093 if (!reshapeOp)
20862094 return failure ();
2087- if (!reshapeOp->hasOneUse ())
2088- return failure ();
20892095
20902096 if (!controlFoldingReshapes (&padOp.getSourceMutable ())) {
20912097 return rewriter.notifyMatchFailure (padOp,
@@ -2098,7 +2104,7 @@ class FoldPadWithProducerReshapeOpByCollapsing
20982104 computeCollapsedPadding (padOp, reassociations, rewriter);
20992105 if (failed (maybeCollapsedPadding))
21002106 return failure ();
2101- PadDimInfo collapsedPadding = maybeCollapsedPadding.value ();
2107+ PadDimInfo & collapsedPadding = maybeCollapsedPadding.value ();
21022108
21032109 SmallVector<OpFoldResult> expandedPaddedSizes =
21042110 reshapeOp.getMixedOutputShape ();
@@ -2147,8 +2153,6 @@ class FoldReshapeWithProducerPadOpByCollapsing
21472153 tensor::PadOp padOp = reshapeOp.getSrc ().getDefiningOp <tensor::PadOp>();
21482154 if (!padOp)
21492155 return failure ();
2150- if (!padOp->hasOneUse ())
2151- return failure ();
21522156
21532157 if (!controlFoldingReshapes (&reshapeOp.getSrcMutable ())) {
21542158 return rewriter.notifyMatchFailure (padOp,
@@ -2162,7 +2166,7 @@ class FoldReshapeWithProducerPadOpByCollapsing
21622166 computeCollapsedPadding (padOp, reassociations, rewriter);
21632167 if (failed (maybeCollapsedPadding))
21642168 return failure ();
2165- PadDimInfo collapsedPadding = maybeCollapsedPadding.value ();
2169+ PadDimInfo & collapsedPadding = maybeCollapsedPadding.value ();
21662170
21672171 Location loc = reshapeOp->getLoc ();
21682172 auto newCollapseOp = tensor::CollapseShapeOp::create (
@@ -2420,8 +2424,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
24202424 controlFoldingReshapes);
24212425 patterns.add <FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext (),
24222426 controlFoldingReshapes);
2423- patterns.add <FoldExpandShapeWithProducerPadOp >(patterns.getContext (),
2424- controlFoldingReshapes);
2427+ patterns.add <FoldReshapeWithProducerPadOpByExpansion >(patterns.getContext (),
2428+ controlFoldingReshapes);
24252429 patterns.add <FoldWithProducerReshapeOpByExpansion>(patterns.getContext (),
24262430 controlFoldingReshapes);
24272431}
0 commit comments