301 changes: 248 additions & 53 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,62 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};

/// Carries per-dimension padding information for a padded tensor.
struct PadDimInfo {
// The resulting shape after padding each dimension.
SmallVector<int64_t> paddedShape;

// Low and high padding amounts for each dimension.
SmallVector<OpFoldResult> lowPad;
SmallVector<OpFoldResult> highPad;
};

/// Computes the expanded padding information for the given pad operation based
/// on the provided expanded shape and reassociation indices. Returns a
/// PadDimInfo containing the low and high padding amounts and the padded
/// shape, or failure if the expansion is not possible.
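///
/// For example (illustrative shapes, not taken from this patch): given a pad
/// with low = [0, 2] and high = [0, 2] on tensor<32x16xf32>, an expanded
/// shape of [8, 4, 16], and reassociations [[0, 1], [2]], only the
/// uncollapsed dimension carries padding, so the result is
/// paddedShape = [8, 4, 20], lowPad = [0, 0, 2], highPad = [0, 0, 2].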
static FailureOr<PadDimInfo>
computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociations,
PatternRewriter &rewriter) {
// If the padding value depends on the index values of the pad operation,
// then it may not be valid to expand the dimensions, since it will change
// the index values on which the padding value depends. This is not currently
// supported by the pad expansion patterns, but it could be implemented
// similarly to the expansion of linalg.generic ops with linalg.index ops in
// the body, as is done in `updateExpandedGenericOpRegion`.
if (!padOp.getConstantPaddingValue())
return failure();

// Expanded dimensions cannot have padding because the resulting padding may
// not be representable by a tensor.pad op. There are some special cases where
// it is possible (like expanding unit dims), but supporting these cases is
// NYI, so disallow it for now.
ArrayRef<int64_t> low = padOp.getStaticLow();
ArrayRef<int64_t> high = padOp.getStaticHigh();
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
if (reInd.size() != 1 && (l != 0 || h != 0))
return failure();
}

SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
PadDimInfo padDimInfo;
padDimInfo.paddedShape.assign(expandedShape);
padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() == 1) {
padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
}
}

return padDimInfo;
}

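/// Pattern to fold a tensor.pad with its tensor.collapse_shape producer by
/// padding the expanded source and collapsing the padded result. A sketch of
/// the rewrite, with illustrative shapes (not taken from this patch):
///
///   %c = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<8x4x16xf32> into tensor<32x16xf32>
///   %p = tensor.pad %c low[0, 2] high[0, 2] { ... }
///       : tensor<32x16xf32> to tensor<32x20xf32>
///
/// becomes:
///
///   %p = tensor.pad %src low[0, 0, 2] high[0, 0, 2] { ... }
///       : tensor<8x4x16xf32> to tensor<8x4x20xf32>
///   %c = tensor.collapse_shape %p [[0, 1], [2]]
///       : tensor<8x4x20xf32> into tensor<32x20xf32>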
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1053,46 +1109,96 @@ class FoldPadWithProducerReshapeOpByExpansion
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
return failure();
if (!reshapeOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}

RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
padOp, expandedType.getShape(), reassociations, rewriter);
if (failed(maybeExpandedPadding))
return failure();
PadDimInfo &expandedPadding = maybeExpandedPadding.value();

Location loc = padOp->getLoc();
RankedTensorType expandedPaddedType =
padOp.getResultType().clone(expandedPadding.paddedShape);

auto newPadOp = tensor::PadOp::create(
rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

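/// Pattern to fold a tensor.expand_shape with its tensor.pad producer by
/// expanding the unpadded source and padding the expanded result. A sketch of
/// the rewrite, with illustrative shapes (not taken from this patch):
///
///   %p = tensor.pad %src low[0, 2] high[0, 2] { ... }
///       : tensor<32x16xf32> to tensor<32x20xf32>
///   %e = tensor.expand_shape %p [[0, 1], [2]] output_shape [8, 4, 20]
///       : tensor<32x20xf32> into tensor<8x4x20xf32>
///
/// becomes:
///
///   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [8, 4, 16]
///       : tensor<32x16xf32> into tensor<8x4x16xf32>
///   %p = tensor.pad %e low[0, 0, 2] high[0, 0, 2] { ... }
///       : tensor<8x4x16xf32> to tensor<8x4x20xf32>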
class FoldReshapeWithProducerPadOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
public:
FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();

if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(expandOp,
"fusion blocked by control function");
}

RankedTensorType expandedType = expandOp.getResultType();
SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
padOp, expandedType.getShape(), reassociations, rewriter);
if (failed(maybeExpandedPadding))
return failure();
PadDimInfo &expandedPadding = maybeExpandedPadding.value();

Location loc = expandOp->getLoc();
SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
SmallVector<int64_t> newExpandedShape(expandedType.getShape());
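// Create the source sizes and the new expand_shape directly after the pad
// source so that they dominate the original pad and expand_shape ops.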
rewriter.setInsertionPointAfterValue(padOp.getSource());
SmallVector<OpFoldResult> padSrcSizes =
tensor::getMixedSizes(rewriter, loc, padOp.getSource());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
// We know that any reassociation with multiple dims is not padded because
// of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}

RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
auto newExpandOp = tensor::ExpandShapeOp::create(
rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
newExpandedSizes);
RankedTensorType expandedPaddedType =
padOp.getResultType().clone(expandedPadding.paddedShape);
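// The new pad replaces the original expand_shape, so create it at the
// expand_shape's position.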
rewriter.setInsertionPoint(expandOp);
auto newPadOp = tensor::PadOp::create(
rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOp(expandOp, newPadOp.getResult());

return success();
}
@@ -1921,6 +2027,62 @@ struct FoldReshapeWithGenericOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};

/// Computes the collapsed padding information for the given pad operation
/// based on the provided reassociation indices. Returns a PadDimInfo
/// containing the low and high padding amounts and the collapsed padded
/// shape, or failure if the collapse is not possible.
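///
/// For example (illustrative shapes, not taken from this patch): given a pad
/// with low = [0, 0, 2] and high = [0, 0, 2] on tensor<8x4x16xf32> and
/// reassociations [[0, 1], [2]], only the dimension that stays uncollapsed
/// carries padding, so the result is paddedShape = [32, 20],
/// lowPad = [0, 2], highPad = [0, 2].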
static FailureOr<PadDimInfo>
computeCollapsedPadding(tensor::PadOp padOp,
ArrayRef<ReassociationIndices> reassociations,
PatternRewriter &rewriter) {
// If the padding value depends on the index values of the pad operation,
// then it may not be valid to collapse the dimensions, since it will change
// the index values on which the padding value depends. This is not currently
// supported by the pad collapsing patterns, but it could be implemented
// similarly to the collapsing of linalg.generic ops with linalg.index ops in
// the body, as is done in `generateCollapsedIndexingRegion`.
if (!padOp.getConstantPaddingValue())
return failure();

// Collapsed dimensions cannot have padding because this can produce strided
// padding that isn't representable by a tensor.pad op. There are some special
// cases where it is possible (like collapsing unit dims), but supporting
// these cases is NYI, so disallow it for now.
ArrayRef<int64_t> low = padOp.getStaticLow();
ArrayRef<int64_t> high = padOp.getStaticHigh();
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
for (int64_t dim : reInd) {
if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
return failure();
}
}

// Initialize the collapsed padding amounts to zero.
ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
PadDimInfo padDimInfo;
padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));

// Update padding for dimensions that are not being collapsed, and compute
// the collapsed padded shape.
SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() == 1) {
padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
}
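// Use SaturatedInteger for the product so that any dynamic extent in the
// group makes the collapsed extent dynamic as well.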
SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
for (int64_t dim : reInd) {
collapsedSize =
collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
}
padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
}

return padDimInfo;
}

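/// Pattern to fold a tensor.pad with its tensor.expand_shape producer by
/// padding the collapsed source and re-expanding the padded result. A sketch
/// of the rewrite, with illustrative shapes (not taken from this patch):
///
///   %e = tensor.expand_shape %src [[0, 1], [2]] output_shape [8, 4, 16]
///       : tensor<32x16xf32> into tensor<8x4x16xf32>
///   %p = tensor.pad %e low[0, 0, 2] high[0, 0, 2] { ... }
///       : tensor<8x4x16xf32> to tensor<8x4x20xf32>
///
/// becomes:
///
///   %p = tensor.pad %src low[0, 2] high[0, 2] { ... }
///       : tensor<32x16xf32> to tensor<32x20xf32>
///   %e = tensor.expand_shape %p [[0, 1], [2]] output_shape [8, 4, 20]
///       : tensor<32x20xf32> into tensor<8x4x20xf32>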
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1936,57 +2098,40 @@ class FoldPadWithProducerReshapeOpByCollapsing
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
return failure();
if (!reshapeOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}

SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
FailureOr<PadDimInfo> maybeCollapsedPadding =
computeCollapsedPadding(padOp, reassociations, rewriter);
if (failed(maybeCollapsedPadding))
return failure();
PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();

SmallVector<OpFoldResult> expandedPaddedSizes =
reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
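// The new expand_shape produces the padded expanded shape, so the size of
// each padded (uncollapsed) dimension is low + high + original size.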
for (auto [reInd, l, h] :
llvm::zip_equal(reassociations, collapsedPadding.lowPad,
collapsedPadding.highPad)) {
if (reInd.size() == 1) {
expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
}
}

RankedTensorType collapsedPaddedType =
padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2145,52 @@ class FoldPadWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};

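/// Pattern to fold a tensor.collapse_shape with its tensor.pad producer by
/// collapsing the unpadded source and padding the collapsed result. A sketch
/// of the rewrite, with illustrative shapes (not taken from this patch):
///
///   %p = tensor.pad %src low[0, 0, 2] high[0, 0, 2] { ... }
///       : tensor<8x4x16xf32> to tensor<8x4x20xf32>
///   %c = tensor.collapse_shape %p [[0, 1], [2]]
///       : tensor<8x4x20xf32> into tensor<32x20xf32>
///
/// becomes:
///
///   %c = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<8x4x16xf32> into tensor<32x16xf32>
///   %p = tensor.pad %c low[0, 2] high[0, 2] { ... }
///       : tensor<32x16xf32> to tensor<32x20xf32>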
class FoldReshapeWithProducerPadOpByCollapsing
: public OpRewritePattern<tensor::CollapseShapeOp> {
public:
FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();

if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(padOp,
"fusion blocked by control function");
}

SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
FailureOr<PadDimInfo> maybeCollapsedPadding =
computeCollapsedPadding(padOp, reassociations, rewriter);
if (failed(maybeCollapsedPadding))
return failure();
PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();

Location loc = reshapeOp->getLoc();
auto newCollapseOp = tensor::CollapseShapeOp::create(
rewriter, loc, padOp.getSource(), reassociations);

auto newPadOp = tensor::PadOp::create(
rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOp(reshapeOp, newPadOp.getResult());
return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2430,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2443,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}