-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] Add shape propagation through tensor.pad #136681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
0d8c636
cf1e560
57ec657
737d4a4
d8ca036
17a2447
9ee8e08
3b91645
9c38ad5
0faf084
9cbd032
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1100,6 +1100,198 @@ class FoldPadWithProducerReshapeOpByExpansion | |
| ControlFusionFn controlFoldingReshapes; | ||
| }; | ||
|
|
||
| /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op | ||
| /// by bubbling the expand_shape before the pad. | ||
| struct FoldReshapeWithProducerPadOpByExpansion | ||
| : public OpRewritePattern<tensor::ExpandShapeOp> { | ||
|
|
||
| FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context, | ||
| ControlFusionFn foldReshapes, | ||
| PatternBenefit benefit = 1) | ||
| : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), | ||
| controlFoldingReshapes(std::move(foldReshapes)) {} | ||
|
|
||
| LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, | ||
| PatternRewriter &rewriter) const override { | ||
| tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>(); | ||
| if (!padOp) | ||
| return failure(); | ||
|
|
||
| if (!padOp->hasOneUse()) | ||
| return failure(); | ||
|
|
||
| if (!controlFoldingReshapes(&expandOp.getSrcMutable())) { | ||
| return rewriter.notifyMatchFailure(expandOp, | ||
| "fusion blocked by control function"); | ||
| } | ||
|
|
||
| Value constantPaddingValue = padOp.getConstantPaddingValue(); | ||
| if (!constantPaddingValue) { | ||
| return rewriter.notifyMatchFailure( | ||
| expandOp, "cannot fold with non-constant padding value"); | ||
| } | ||
|
|
||
| SmallVector<ReassociationIndices> reassociations = | ||
| expandOp.getReassociationIndices(); | ||
| SmallVector<OpFoldResult> low = padOp.getMixedLowPad(); | ||
| SmallVector<OpFoldResult> high = padOp.getMixedHighPad(); | ||
|
|
||
| for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { | ||
| if (reInd.size() > 1 && | ||
| (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0))) | ||
| return rewriter.notifyMatchFailure( | ||
| expandOp, "fusion blocked by non-zero padding"); | ||
| } | ||
|
|
||
| SmallVector<OpFoldResult> newLow, newHigh; | ||
| for (auto [idx, reInd] : llvm::enumerate(reassociations)) { | ||
| newLow.append(reInd.size(), low[idx]); | ||
| newHigh.append(reInd.size(), high[idx]); | ||
| } | ||
|
||
|
|
||
| Location loc = expandOp.getLoc(); | ||
| ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape(); | ||
| SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape(); | ||
|
|
||
| for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { | ||
| OpFoldResult l = low[inDimIdx]; | ||
| OpFoldResult h = high[inDimIdx]; | ||
Max191 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { | ||
| assert(reInd.size() == 1 && "expected single dimension"); | ||
| expandedShape[reInd[0]] = | ||
| tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx); | ||
| ; | ||
| } | ||
| } | ||
|
|
||
| for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) { | ||
| if (dimSize == ShapedType::kDynamic && | ||
| !isa<Value>(expandedShape[outDimIdx]) && | ||
| !isa<Attribute>(expandedShape[outDimIdx])) { | ||
| expandedShape[outDimIdx] = | ||
| tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx); | ||
| } | ||
| } | ||
|
||
|
|
||
| SmallVector<int64_t> staticExpandedShape; | ||
| std::tie(staticExpandedShape, std::ignore) = | ||
| decomposeMixedValues(expandedShape); | ||
|
|
||
| auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( | ||
| loc, | ||
| RankedTensorType::get(staticExpandedShape, | ||
| padOp.getSource().getType().getElementType()), | ||
| padOp.getSource(), reassociations, expandedShape); | ||
|
|
||
| rewriter.replaceOpWithNewOp<tensor::PadOp>( | ||
| expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh, | ||
| padOp.getConstantPaddingValue(), padOp.getNofold()); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| ControlFusionFn controlFoldingReshapes; | ||
| }; | ||
|
|
||
| /// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op | ||
| /// by bubbling the collapse_shape before the pad. | ||
| struct FoldReshapeWithProducerPadOpByCollapsing | ||
| : public OpRewritePattern<tensor::CollapseShapeOp> { | ||
|
|
||
| FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context, | ||
| ControlFusionFn foldReshapes, | ||
| PatternBenefit benefit = 1) | ||
| : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit), | ||
| controlFoldingReshapes(std::move(foldReshapes)) {} | ||
|
|
||
| LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp, | ||
| PatternRewriter &rewriter) const override { | ||
| tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>(); | ||
|
|
||
| if (!padOp) | ||
| return failure(); | ||
|
|
||
| if (!padOp->hasOneUse()) | ||
| return failure(); | ||
|
|
||
| if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) { | ||
| return rewriter.notifyMatchFailure(collapseOp, | ||
| "fusion blocked by control function"); | ||
| } | ||
|
|
||
| Value constantPaddingValue = padOp.getConstantPaddingValue(); | ||
| if (!constantPaddingValue) { | ||
| return rewriter.notifyMatchFailure( | ||
| collapseOp, "cannot fold with non-constant padding value"); | ||
| } | ||
|
|
||
| SmallVector<ReassociationIndices> reassociations = | ||
| collapseOp.getReassociationIndices(); | ||
| SmallVector<OpFoldResult> low = padOp.getMixedLowPad(); | ||
| SmallVector<OpFoldResult> high = padOp.getMixedHighPad(); | ||
|
|
||
| for (auto [idx, reInd] : llvm::enumerate(reassociations)) { | ||
| if (reInd.size() > 1) { | ||
| for (auto dimIdx : reInd) { | ||
| if (!isConstantIntValue(low[dimIdx], 0) || | ||
| !isConstantIntValue(high[dimIdx], 0)) { | ||
| return failure(); | ||
| } | ||
| } | ||
| } | ||
ita9naiwa marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| SmallVector<OpFoldResult> newLow, newHigh; | ||
| for (auto [idx, reInd] : llvm::enumerate(reassociations)) { | ||
| newLow.push_back(low[reInd[0]]); | ||
| newHigh.push_back(high[reInd[0]]); | ||
| } | ||
|
|
||
|
||
| Location loc = collapseOp.getLoc(); | ||
| auto resultType = collapseOp.getResultType(); | ||
|
|
||
| ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape(); | ||
|
|
||
| SmallVector<OpFoldResult> collapsedShape; | ||
| for (int64_t dimSize : finalShape) { | ||
| if (dimSize == ShapedType::kDynamic) { | ||
| collapsedShape.push_back(OpFoldResult{}); | ||
| } else { | ||
| collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize)); | ||
| } | ||
| } | ||
|
||
|
|
||
| for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) { | ||
| OpFoldResult l = low[reInd[0]]; | ||
| OpFoldResult h = high[reInd[0]]; | ||
|
|
||
| if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) { | ||
| collapsedShape[inDimIdx] = | ||
| tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]); | ||
| } | ||
| } | ||
|
|
||
| SmallVector<int64_t> staticCollapsedShape; | ||
| std::tie(staticCollapsedShape, std::ignore) = | ||
| decomposeMixedValues(collapsedShape); | ||
|
|
||
| auto newCollapseType = RankedTensorType::get( | ||
| staticCollapsedShape, padOp.getSource().getType().getElementType()); | ||
| auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( | ||
| loc, newCollapseType, padOp.getSource(), reassociations); | ||
|
|
||
| rewriter.replaceOpWithNewOp<tensor::PadOp>( | ||
| collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh, | ||
| padOp.getConstantPaddingValue(), padOp.getNofold()); | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| ControlFusionFn controlFoldingReshapes; | ||
| }; | ||
|
|
||
| /// Pattern to fold a tensor.expand_shape op with its producer generic op | ||
| /// by expanding the dimensionality of the loop in the producer op. | ||
| struct FoldReshapeWithGenericOpByExpansion | ||
|
|
@@ -2235,6 +2427,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( | |
| controlFoldingReshapes); | ||
| patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), | ||
| controlFoldingReshapes); | ||
| patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(), | ||
| controlFoldingReshapes); | ||
| patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), | ||
| controlFoldingReshapes); | ||
| } | ||
|
|
@@ -2246,6 +2440,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( | |
| controlFoldingReshapes); | ||
| patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( | ||
| patterns.getContext(), controlFoldingReshapes); | ||
| patterns.add<FoldReshapeWithProducerPadOpByCollapsing>( | ||
| patterns.getContext(), controlFoldingReshapes); | ||
|
|
||
| patterns.add<FoldReshapeWithProducerPadOpByCollapsing>( | ||
| patterns.getContext(), controlFoldingReshapes); | ||
| patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(), | ||
| controlFoldingReshapes); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me, "folding" would mean that
tensor.expand_shapedisappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317