[mlir] Add missing pad reshape propagation patterns #168888
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: None (Max191)

Changes: The existing […]

Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168888.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..8c5a0c1474408 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+ // The resulting shape after padding each dimension.
+ SmallVector<int64_t> paddedShape;
+
+ // Low and high padding amounts for each dimension.
+ SmallVector<OpFoldResult> lowPad;
+ SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// shape for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Expanded dimensions cannot have padding because the resulting padding may
+ // not be representable by a tensor.pad op. There are some special cases where
+ // it is possible (like expanding unit dims), but supporting these cases is
+ // NYI, so disallow it for now.
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.paddedShape.assign(expandedShape);
+ padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+ padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+ padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+ }
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
}
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType expandedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ RankedTensorType expandedType = expandOp.getResultType();
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+ SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+ rewriter.setInsertionPointAfterValue(padOp.getSource());
+ SmallVector<OpFoldResult> padSrcSizes =
+ tensor::getMixedSizes(rewriter, loc, padOp.getSource());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ // We know that any reassociation with multiple dims is not padded because
+ // of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
- expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
- }
- for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+ newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}
-
- Location loc = padOp->getLoc();
- RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+ auto newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+ newExpandedSizes);
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+ rewriter.setInsertionPoint(expandOp);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+ expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
return success();
}
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Collapsed dimensions cannot have padding because this can produce strided
+ // padding that isn't representable by a tensor.pad op. There are some special
+ // cases where it is possible (like collapsing unit dims), but supporting
+ // these cases is NYI, so disallow it for now.
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (int64_t dim : reInd) {
+ if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+ return failure();
+ }
+ }
+
+ // Initialize padding values for collapsed tensors with zeros
+ ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+ // Update padding for dimensions that are not being collapsed, and compute
+ // the collapsed padded shape.
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+ padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+ }
+ SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+ for (int64_t dim : reInd) {
+ collapsedSize =
+ collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+ }
+ padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
- for (auto reInd : reassociations) {
- if (reInd.size() == 1)
- continue;
- if (llvm::any_of(reInd, [&](int64_t ind) {
- return low[ind] != 0 || high[ind] != 0;
- })) {
- return failure();
- }
- }
-
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
- SmallVector<OpFoldResult> expandedPaddedSizes(
- getMixedValues(reshapeOp.getStaticOutputShape(),
- reshapeOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> expandedPaddedSizes =
+ reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
- for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
- OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
- OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+ for (auto [reInd, l, h] :
+ llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+ collapsedPadding.highPad)) {
if (reInd.size() == 1) {
- collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
- OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
- expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(l);
- newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
- paddedType.clone(collapsedPaddedShape);
+ padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+class FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(padOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+ RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+ Location loc = reshapeOp->getLoc();
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, padOp.getSource(), reassociations);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..923bb2ca9c28a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
+// -----
+
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+ : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+ return %collapsed : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+ : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %collapsed : tensor<?x?x?x?xf32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: return %[[PAD]]
+
// -----
// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..f6572674d10e2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// -----
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+ : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+ return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\...
[truncated]
IanWood1 left a comment:
LGTM!
The existing FoldPadWithProducerReshapeOpByExpansion and FoldPadWithProducerReshapeOpByCollapsing patterns did not cover all reshape propagation cases, because they only consider cases where the pad op is the consumer operation. This PR adds two new patterns to cover the cases where the pad op is the producer operation, which completes the set of propagation patterns for pad ops with expand_shape and collapse_shape.

Note for integration: This PR also removes the single-user restriction from the FoldPadWithProducerReshapeOpByExpansion and FoldPadWithProducerReshapeOpByCollapsing patterns, which leaves more control to users of the patterns. If this constraint is needed, it should be added to the control function for these patterns.
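For readers skimming the diff, here is a rough before/after sketch of the new producer-side direction, adapted from the expand_shape_with_producer_pad test above. It is illustrative only: the pad body that yields the constant padding value is elided as { ... }, and the exact IR the pattern emits may differ slightly.

// Before: the pad op is the producer of the expand_shape.
%padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] { ... }
    : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
%expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
    output_shape [8, 3, 4, 17, 6, 7, 8, 14]
    : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>

// After: the expand_shape is bubbled up through the pad. Dimensions that map
// one-to-one through the reassociation ([0], [3], [7]) carry the padding;
// the expanded groups ([1, 2] and [4, 5, 6]) must be unpadded.
%expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
    output_shape [2, 3, 4, 5, 6, 7, 8, 9]
    : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
%padded = tensor.pad %expanded low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] { ... }
    : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>

The same restriction applies in both propagation directions: only dimensions that map one-to-one through the reassociation may carry padding, which is what computeExpandedPadding and computeCollapsedPadding check before the rewrite is attempted.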