
Commit af2e246

[mlir] Add missing pad reshape propagation patterns (#168888)
The existing `FoldPadWithProducerReshapeOpByExpansion` and `FoldPadWithProducerReshapeOpByCollapsing` patterns did not cover all reshape propagation cases, because they only consider cases where the pad op is the consumer operation. This PR adds two new patterns covering the cases where the pad op is the producer operation, which completes the set of propagation patterns for pad ops with expand_shape and collapse_shape.

Note for integration: this PR also removes the single-user restriction from the `FoldPadWithProducerReshapeOpByExpansion` and `FoldPadWithProducerReshapeOpByCollapsing` patterns, leaving more control to users of the patterns. If this constraint is needed, it should be added to the control function passed to these patterns.

---------

Signed-off-by: Max Dawkins <[email protected]>
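To make the producer-pad direction concrete, here is a hand-written sketch of the kind of rewrite the new `FoldReshapeWithProducerPadOpByExpansion` pattern performs. The values (`%src`, `%cst`) and all shapes are illustrative, not taken from the PR's tests:

// Before: tensor.pad is the producer of tensor.expand_shape.
%padded = tensor.pad %src low[0, 1] high[0, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x16xf32> to tensor<8x18xf32>
%expanded = tensor.expand_shape %padded [[0, 1], [2]] output_shape [2, 4, 18]
    : tensor<8x18xf32> into tensor<2x4x18xf32>

// After: the expand_shape is propagated above the pad. The padded dimension
// maps to a single result dimension, so its padding carries over unchanged;
// dimensions that are actually expanded must be unpadded.
%expanded = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 16]
    : tensor<8x16xf32> into tensor<2x4x16xf32>
%padded = tensor.pad %expanded low[0, 0, 1] high[0, 0, 1] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x4x16xf32> to tensor<2x4x18xf32>

The pre-existing patterns handle the opposite order, where the reshape is the producer and the pad is the consumer.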
1 parent 7b2ee46 commit af2e246

File tree: 3 files changed, +398 −53 lines changed


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 248 additions & 53 deletions
@@ -1038,6 +1038,62 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to expand the dimensions, since it will change
+  // the index values on which the padding value depends. This is not currently
+  // supported by the pad expansion patterns, but it could be implemented
+  // similarly to the expansion of linalg.generic ops with linalg.index ops in
+  // the body, as is done in `updateExpandedGenericOpRegion`.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases where
+  // it is possible (like expanding unit dims), but supporting these cases is
+  // NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1053,46 +1109,96 @@ class FoldPadWithProducerReshapeOpByExpansion
         padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // We know that any reassociation with multiple dims is not padded because
+      // of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
 
     return success();
   }
@@ -1921,6 +2027,62 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to collapse the dimensions, since it will change
+  // the index values on which the padding value depends. This is not currently
+  // supported by the pad collapsing patterns, but it could be implemented
+  // similarly to the collapsing of linalg.generic ops with linalg.index ops in
+  // the body, as is done in `generateCollapsedIndexingRegion`.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some special
+  // cases where it is possible (like collapsing unit dims), but supporting
+  // these cases is NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros.
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+      padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
+    }
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1936,57 +2098,40 @@ class FoldPadWithProducerReshapeOpByCollapsing
         padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
 
-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }
 
     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
        padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2145,52 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
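For illustration, a comparable hand-written sketch of the rewrite performed by the new `FoldReshapeWithProducerPadOpByCollapsing` pattern above (shapes and values again illustrative, not taken from the PR's tests):

// Before: tensor.pad is the producer of tensor.collapse_shape.
%padded = tensor.pad %src low[0, 0, 2] high[0, 0, 2] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x4x16xf32> to tensor<2x4x20xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
    : tensor<2x4x20xf32> into tensor<8x20xf32>

// After: the collapse_shape is propagated above the pad. Per
// computeCollapsedPadding, only dimensions in singleton reassociation
// groups may carry padding; the collapsed group [0, 1] is unpadded.
%collapsed = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<2x4x16xf32> into tensor<8x16xf32>
%padded = tensor.pad %collapsed low[0, 2] high[0, 2] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x16xf32> to tensor<8x20xf32>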
@@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
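With the single-user restriction removed from the pad patterns, callers that relied on it must now enforce it through the control function, as the commit message notes. A minimal hypothetical sketch of caller code (not part of this commit; assumes an `MLIRContext *context` is in scope):

// Reinstate the old "producer has one use" behavior via the control
// function instead of relying on the patterns themselves.
RewritePatternSet patterns(context);
linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
  // Only fold when the fused producer's result has a single consumer.
  return fusedOperand->get().hasOneUse();
};
linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);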
