Commit f76287c
Parent: bbda86d

address comments

Signed-off-by: Max Dawkins <[email protected]>


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 30 additions & 26 deletions
@@ -1050,19 +1050,24 @@ struct PadDimInfo {
 
 /// Computes the expanded padding information for the given pad operation based
 /// on the provided expanded shape and reassociation indices. Returns a list of
-/// PaddedDimInfo containing the low and high padding amounts and the padded
+/// PadDimInfo containing the low and high padding amounts and the padded
 /// size for each dimension, or failure if the expansion is not possible.
 static FailureOr<PadDimInfo>
 computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
                        ArrayRef<ReassociationIndices> reassociations,
                        PatternRewriter &rewriter) {
-  ArrayRef<int64_t> low = padOp.getStaticLow();
-  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to expand the dimensions, since it will change
+  // the index values on which the padding value depends.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
 
   // Expanded dimensions cannot have padding because the resulting padding may
   // not be representable by a tensor.pad op. There are some special cases where
   // it is possible (like expanding unit dims), but supporting these cases is
   // NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
   for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
     if (reInd.size() != 1 && (l != 0 || h != 0))
       return failure();
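
Note: a hypothetical IR example (shapes made up) of the case the new
getConstantPaddingValue() guard rejects. The yielded padding value depends on
the pad body's index arguments, so expanding the padded tensor's dimensions
would change the indices the value is computed from:

  %padded = tensor.pad %src low[2, 0] high[2, 0] {
  ^bb0(%i: index, %j: index):
    // The yielded value varies with %i, so there is no constant padding value.
    %ii = arith.index_cast %i : index to i32
    %v = arith.sitofp %ii : i32 to f32
    tensor.yield %v : f32
  } : tensor<6x8xf32> to tensor<10x8xf32>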
@@ -1101,8 +1106,6 @@ class FoldPadWithProducerReshapeOpByExpansion
         padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
@@ -1116,7 +1119,7 @@ class FoldPadWithProducerReshapeOpByExpansion
         padOp, expandedType.getShape(), reassociations, rewriter);
     if (failed(maybeExpandedPadding))
       return failure();
-    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();
 
     Location loc = padOp->getLoc();
     RankedTensorType expandedPaddedType =
@@ -1137,12 +1140,12 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
-class FoldExpandShapeWithProducerPadOp
+class FoldReshapeWithProducerPadOpByExpansion
     : public OpRewritePattern<tensor::ExpandShapeOp> {
 public:
-  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
-                                   ControlFusionFn foldReshapes,
-                                   PatternBenefit benefit = 1)
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
         controlFoldingReshapes(std::move(foldReshapes)) {}
 
@@ -1151,8 +1154,6 @@ class FoldExpandShapeWithProducerPadOp
     tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
     if (!padOp)
       return failure();
-    if (!padOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
       return rewriter.notifyMatchFailure(expandOp,
@@ -1166,7 +1167,7 @@ class FoldExpandShapeWithProducerPadOp
         padOp, expandedType.getShape(), reassociations, rewriter);
     if (failed(maybeExpandedPadding))
       return failure();
-    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+    PadDimInfo &expandedPadding = maybeExpandedPadding.value();
 
     Location loc = expandOp->getLoc();
     SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
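
For reference, a sketch (with made-up shapes) of the rewrite these expansion
patterns perform: the tensor.expand_shape is bubbled above the tensor.pad, and
padding stays confined to a dimension whose reassociation group has size 1.

  // Before:
  %pad = tensor.pad %src low[0, 1] high[0, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<8x10xf32> to tensor<8x12xf32>
  %res = tensor.expand_shape %pad [[0, 1], [2]] output_shape [2, 4, 12]
      : tensor<8x12xf32> into tensor<2x4x12xf32>

  // After:
  %exp = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 10]
      : tensor<8x10xf32> into tensor<2x4x10xf32>
  %res = tensor.pad %exp low[0, 0, 1] high[0, 0, 1] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f32
  } : tensor<2x4x10xf32> to tensor<2x4x12xf32>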
@@ -2031,13 +2032,18 @@ static FailureOr<PadDimInfo>
 computeCollapsedPadding(tensor::PadOp padOp,
                         ArrayRef<ReassociationIndices> reassociations,
                         PatternRewriter &rewriter) {
-  ArrayRef<int64_t> low = padOp.getStaticLow();
-  ArrayRef<int64_t> high = padOp.getStaticHigh();
+  // If the padding value depends on the index values of the pad operation,
+  // then it may not be valid to collapse the dimensions, since it will change
+  // the index values on which the padding value depends.
+  if (!padOp.getConstantPaddingValue())
+    return failure();
 
   // Collapsed dimensions cannot have padding because this can produce strided
   // padding that isn't representable by a tensor.pad op. There are some special
-  // cases where it it possible (like collapsing unit dims), but supporting
+  // cases where it is possible (like collapsing unit dims), but supporting
   // these cases is NYI, so disallow it for now.
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
   for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
     for (int64_t dim : reInd) {
       if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
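
To illustrate the strided-padding restriction with a made-up example:
collapsing a padded dimension that sits inside a larger reassociation group
would require pad elements interleaved with source elements, which a single
tensor.pad cannot express.

  %pad = tensor.pad %src low[0, 0] high[0, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<3x4xf32> to tensor<3x6xf32>
  // Collapsing [[0, 1]] to tensor<18xf32> would need 2 pad elements after
  // every 4 source elements; that striding is not representable, so
  // computeCollapsedPadding returns failure for this case.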
@@ -2053,10 +2059,12 @@ computeCollapsedPadding(tensor::PadOp padOp,
 
   // Update padding for dimensions that are not being collapsed, and compute
   // the collapsed padded shape.
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
   for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
     if (reInd.size() == 1) {
-      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
-      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+      padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+      padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
     }
     SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
     for (int64_t dim : reInd) {
@@ -2084,8 +2092,6 @@ class FoldPadWithProducerReshapeOpByCollapsing
         padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
     if (!reshapeOp)
       return failure();
-    if (!reshapeOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
       return rewriter.notifyMatchFailure(padOp,
@@ -2098,7 +2104,7 @@ class FoldPadWithProducerReshapeOpByCollapsing
         computeCollapsedPadding(padOp, reassociations, rewriter);
     if (failed(maybeCollapsedPadding))
       return failure();
-    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
 
     SmallVector<OpFoldResult> expandedPaddedSizes =
         reshapeOp.getMixedOutputShape();
@@ -2147,8 +2153,6 @@ class FoldReshapeWithProducerPadOpByCollapsing
     tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
     if (!padOp)
       return failure();
-    if (!padOp->hasOneUse())
-      return failure();
 
     if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
       return rewriter.notifyMatchFailure(padOp,
@@ -2162,7 +2166,7 @@ class FoldReshapeWithProducerPadOpByCollapsing
         computeCollapsedPadding(padOp, reassociations, rewriter);
     if (failed(maybeCollapsedPadding))
       return failure();
-    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+    PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
 
     Location loc = reshapeOp->getLoc();
     auto newCollapseOp = tensor::CollapseShapeOp::create(
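
Analogously to the expansion case, a sketch (made-up shapes) of the
collapse-side rewrite: the tensor.collapse_shape is hoisted above the
tensor.pad, with padding kept on a reassociation group of size 1.

  // Before:
  %pad = tensor.pad %src low[0, 0, 1] high[0, 0, 1] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f32
  } : tensor<2x4x10xf32> to tensor<2x4x12xf32>
  %res = tensor.collapse_shape %pad [[0, 1], [2]]
      : tensor<2x4x12xf32> into tensor<8x12xf32>

  // After:
  %col = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<2x4x10xf32> into tensor<8x10xf32>
  %res = tensor.pad %col low[0, 1] high[0, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<8x10xf32> to tensor<8x12xf32>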
@@ -2420,8 +2424,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                      controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
-  patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
-                                                 controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
