
Commit bbda86d

[mlir] Add missing pad reshape propagation patterns
Signed-off-by: Max Dawkins <[email protected]>
1 parent aa3f930 commit bbda86d

3 files changed: +314 -49 lines changed
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 234 additions & 49 deletions
```diff
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded size
+/// for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases where
+  // it is possible (like expanding unit dims), but supporting these cases is
+  // NYI, so disallow it for now.
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
```
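For intuition, here is a minimal MLIR sketch of the restriction (hypothetical IR; `%src` and the shapes are chosen for illustration). Padding on a dimension whose reassociation group has a single index carries over to the corresponding expanded dimension unchanged; padding on a dimension that is being split would have to land in the middle of the expanded dims, which a single `tensor.pad` cannot express, so the helper returns failure for that case:

```mlir
%cst = arith.constant 0.0 : f32
// OK with groups [[0, 1], [2]]: dim 1 is padded and its group [2] has a
// single index, so the pad maps one-to-one onto expanded dim 2.
%pad = tensor.pad %src low[0, 2] high[0, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x30xf32> to tensor<8x35xf32>
// Rejected: if dim 0 (split into expanded dims [0, 1]) were padded instead,
// the padding would be interleaved across the expanded dims.
```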
```diff
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+                                   ControlFusionFn foldReshapes,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // We know that any reassociation with multiple dims is not padded because
+      // of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
 
     return success();
   }
```
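The net effect of the new `FoldExpandShapeWithProducerPadOp` pattern, sketched on hypothetical IR (shapes and value names invented; `%src : tensor<8x30xf32>` and `%cst` assumed defined): the `tensor.expand_shape` is hoisted above the `tensor.pad`, so the pad applies to the already-expanded tensor and can keep propagating upward.

```mlir
// Before: pad of dim 1 feeds an expand_shape that splits only dim 0.
%pad = tensor.pad %src low[0, 2] high[0, 3] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x30xf32> to tensor<8x35xf32>
%expand = tensor.expand_shape %pad [[0, 1], [2]] output_shape [2, 4, 35]
    : tensor<8x35xf32> into tensor<2x4x35xf32>

// After: the pad source is expanded first, then padded on the mapped dim.
%expand_src = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 30]
    : tensor<8x30xf32> into tensor<2x4x30xf32>
%new_pad = tensor.pad %expand_src low[0, 0, 2] high[0, 0, 3] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x4x30xf32> to tensor<2x4x35xf32>
```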
```diff
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation
+/// based on the provided reassociation indices. Returns a PadDimInfo
+/// containing the low and high padding amounts and the collapsed shape for
+/// each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some special
+  // cases where it is possible (like collapsing unit dims), but supporting
+  // these cases is NYI, so disallow it for now.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros.
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+    }
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
```
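A worked example of the collapsed-shape arithmetic (illustrative values; `%padded` assumed defined): collapsing the padded type with groups `[[0, 1], [2]]` multiplies the static sizes within each group, and the `SaturatedInteger` arithmetic turns any group containing a dynamic size into a dynamic result.

```mlir
// [2, 4, 33] with groups [[0, 1], [2]] -> [2 * 4, 33] = [8, 33]
%c = tensor.collapse_shape %padded [[0, 1], [2]]
    : tensor<2x4x33xf32> into tensor<8x33xf32>
// With a dynamic size the product saturates: [?, 4, 33] -> [?, 33].
```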
```diff
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
 
-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }
 
     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
```
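The `d0 + d1 + d2` map folds the low pad, high pad, and the reshape's output size into the padded output size used by the trailing `tensor.expand_shape`. A sketch with illustrative values (`%low`, `%high`, `%size` are hypothetical index values): fully static operands fold away to a constant (`1 + 2 + 30 = 33`), while a dynamic operand materializes an `affine.apply`:

```mlir
%padded_size = affine.apply
    affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%low, %high, %size)
```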
```diff
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
```
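Symmetrically, the new `FoldReshapeWithProducerPadOpByCollapsing` pattern hoists a `tensor.collapse_shape` above its producer pad. A sketch on hypothetical IR (`%src : tensor<2x4x30xf32>` and `%cst` assumed defined):

```mlir
// Before: pad of dim 2 feeds a collapse_shape that merges dims 0 and 1.
%pad = tensor.pad %src low[0, 0, 1] high[0, 0, 2] {
^bb0(%i: index, %j: index, %k: index):
  tensor.yield %cst : f32
} : tensor<2x4x30xf32> to tensor<2x4x33xf32>
%collapse = tensor.collapse_shape %pad [[0, 1], [2]]
    : tensor<2x4x33xf32> into tensor<8x33xf32>

// After: the source is collapsed first, then padded on the kept dim.
%collapse_src = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<2x4x30xf32> into tensor<8x30xf32>
%new_pad = tensor.pad %collapse_src low[0, 1] high[0, 2] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<8x30xf32> to tensor<8x33xf32>
```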
```diff
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+                                                 controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
```
