270 changes: 270 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
Contributor:
nit: I don't think these includes are needed

Contributor Author:
Done

Contributor:
Seems like these are still here?

Contributor Author:
Sorry, done.

#include <optional>
#include <utility>

@@ -1100,6 +1102,267 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};

bool isZero(OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value)) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr))
return intAttr.getInt() == 0;
}
if (auto val = dyn_cast<Value>(value)) {
if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
return attr.getInt() == 0;
}
}
return false;
}
Contributor:
You can use isConstantIntValue(value, 0) for this:

bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
return getConstantIntValue(ofr) == value;
}

Contributor Author:
Done


/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
/// by bubbling the expand_shape before the pad.
Contributor:
To me, "folding" would mean that tensor.expand_shape disappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".

Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317
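For example, the doc comment could carry a small before/after snippet along these lines (a sketch; the shapes and SSA names are illustrative, not taken from the patch):

/// Before:
///   %padded = tensor.pad %src low[0, 1] high[0, 1] {
///     ^bb0(%i0: index, %i1: index):
///       tensor.yield %cst : f32
///   } : tensor<32x64xf32> to tensor<32x66xf32>
///   %expanded = tensor.expand_shape %padded [[0, 1], [2]] output_shape [4, 8, 66]
///       : tensor<32x66xf32> into tensor<4x8x66xf32>
///
/// After:
///   %expanded = tensor.expand_shape %src [[0, 1], [2]] output_shape [4, 8, 64]
///       : tensor<32x64xf32> into tensor<4x8x64xf32>
///   %padded = tensor.pad %expanded low[0, 0, 1] high[0, 0, 1] {
///     ^bb0(%i0: index, %i1: index, %i2: index):
///       tensor.yield %cst : f32
///   } : tensor<4x8x64xf32> to tensor<4x8x66xf32>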

struct FoldReshapeWithProducerPadOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {

FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(expandOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
expandOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[idx];
OpFoldResult h = high[idx];
Contributor:
nit: you can use llvm::zip_equal

Suggested change:
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = low[idx];
-      OpFoldResult h = high[idx];
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {

Contributor Author:
Done

if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
return failure();
Contributor:
nit: Use rewriter.notifyMatchFailure() like above?
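For instance, something like this (a sketch; the message text is just illustrative):

if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
  return rewriter.notifyMatchFailure(
      expandOp, "cannot bubble expand_shape through non-zero padding on grouped dims");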

Contributor Author:
Done

}

SmallVector<OpFoldResult> newLow, newHigh;
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
for (size_t i = 0; i < reInd.size(); ++i) {
newLow.push_back(low[idx]);
newHigh.push_back(high[idx]);
}
Contributor:
nit: you can use append

Suggested change:
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(low[idx]);
-        newHigh.push_back(high[idx]);
-      }
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);

Contributor Author:
Done

}

Location loc = expandOp.getLoc();
auto finalType = cast<RankedTensorType>(expandOp.getType());
ArrayRef<int64_t> finalShape = finalType.getShape();
Contributor:
nit: expandOp.getResultType() will return a RankedTensorType, so you can do:

Suggested change:
-    auto finalType = cast<RankedTensorType>(expandOp.getType());
-    ArrayRef<int64_t> finalShape = finalType.getShape();
+    ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();


SmallVector<OpFoldResult> expandedShape;
for (int64_t dimSize : finalShape) {
if (dimSize == ShapedType::kDynamic) {
expandedShape.push_back(OpFoldResult{});
} else {
expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
}
}
Contributor:
I think what you're looking for is expandOp.getMixedOutputShape().

Suggested change:
-    SmallVector<OpFoldResult> expandedShape;
-    for (int64_t dimSize : finalShape) {
-      if (dimSize == ShapedType::kDynamic) {
-        expandedShape.push_back(OpFoldResult{});
-      } else {
-        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
-      }
-    }
+    SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();

Contributor Author:
/Users/ita/src/llvm-project/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp:1257:59: error: no member named 'getMixedOutputShape' in 'mlir::tensor::CollapseShapeOp'
 1257 |     SmallVector<OpFoldResult> collapsedShape = collapseOp.getMixedOutputShape();

Contributor Author:
I'll make a separate PR for this support. I think CollapseShapeOp may need a getMixedOutputShape too.

Contributor:
CollapseShape doesn't need to carry the output shape, so it won't have it, but you can use it for the expand_shape pattern.


for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[inDimIdx];
OpFoldResult h = high[inDimIdx];

if (!isZero(l) || !isZero(h)) {
Contributor:
nit: save a level of nesting with early continue, i.e.:

if (isZero(l) && isZero(h)) {
  continue;
}

Contributor Author:
I think it's better not to apply this change: the code is already simpler with your suggestions, and the logic is simple enough as is.

auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
int64_t originalSize = srcType.getDimSize(inDimIdx);

OpFoldResult originalSizeOFR;
if (originalSize == ShapedType::kDynamic) {
Value orgSizeVal =
rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
originalSizeOFR = orgSizeVal;
} else {
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
}
Contributor:
I think what you're looking for is tensor::getMixedSize:

Suggested change:
-        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
-        int64_t originalSize = srcType.getDimSize(inDimIdx);
-        OpFoldResult originalSizeOFR;
-        if (originalSize == ShapedType::kDynamic) {
-          Value orgSizeVal =
-              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
-          originalSizeOFR = orgSizeVal;
-        } else {
-          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
-        }
+        OpFoldResult originalSize = tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);

Contributor Author:
Done

Contributor:
This seems to still be unchanged too?

Contributor Author:
When I apply this, I see a test failure, but I'm still working on applying it.

Contributor Author:
Sorry, that was my mistake. Solved!


for (auto outDimIdx : reInd) {
expandedShape[outDimIdx] = originalSizeOFR;
}
Contributor:
I know that the reInd should have a size of 1 from the previous matching, but I think the logic is more clear if you add an assert here that reInd.size() == 1, and then just do expandedShape[reInd[0]] = originalSizeOFR;
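That is, roughly (a sketch of the suggestion):

assert(reInd.size() == 1 && "expected a single expanded dim for a padded dim");
expandedShape[reInd[0]] = originalSizeOFR;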

Contributor Author:
Done

}
}

for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
if (dimSize == ShapedType::kDynamic &&
!isa<Value>(expandedShape[outDimIdx]) &&
!isa<Attribute>(expandedShape[outDimIdx])) {
Value actualSize =
rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
expandedShape[outDimIdx] = actualSize;
}
}
Contributor:
I think this was necessary because some of the expandedShape entries were null, right? I'm pretty sure this shouldn't be needed if you use getMixedOutputShape as per my comment above.

Contributor Author:
Did I understand your comment correctly? This can be reduced to:

for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
      if (dimSize == ShapedType::kDynamic &&
          !isa<Value>(expandedShape[outDimIdx]) &&
          !isa<Attribute>(expandedShape[outDimIdx])) {
        expandedShape[outDimIdx] =
            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
      }
    }


SmallVector<int64_t> staticExpandedShape;
for (OpFoldResult dim : expandedShape) {
if (auto attr = dyn_cast<Attribute>(dim)) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
staticExpandedShape.push_back(intAttr.getInt());
} else {
staticExpandedShape.push_back(ShapedType::kDynamic);
}
} else {
staticExpandedShape.push_back(ShapedType::kDynamic);
}
}
Contributor:
You can use decomposeMixedValues here:

Suggested change:
-    SmallVector<int64_t> staticExpandedShape;
-    for (OpFoldResult dim : expandedShape) {
-      if (auto attr = dyn_cast<Attribute>(dim)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-          staticExpandedShape.push_back(intAttr.getInt());
-        } else {
-          staticExpandedShape.push_back(ShapedType::kDynamic);
-        }
-      } else {
-        staticExpandedShape.push_back(ShapedType::kDynamic);
-      }
-    }
+    SmallVector<int64_t> staticExpandedShape;
+    std::tie(staticExpandedShape, std::ignore) = decomposeMixedValues(expandedShape);

Contributor Author:
Done


auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
loc,
RankedTensorType::get(staticExpandedShape,
padOp.getSource().getType().getElementType()),
padOp.getSource(), reassociations);
Contributor:
I think you also want to pass the mixed output shape here to use the correct builder:

Suggested change:
-    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc,
-        RankedTensorType::get(staticExpandedShape,
-                              padOp.getSource().getType().getElementType()),
-        padOp.getSource(), reassociations);
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations, expandedShape);

Contributor Author:
Done


auto newPadOp = rewriter.create<tensor::PadOp>(
loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOp(expandOp, newPadOp.getResult());
Contributor:
nit: use rewriter.replaceOpWithNewOp<tensor::PadOp>?
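That would look roughly like this (a sketch, reusing the operands already built above):

rewriter.replaceOpWithNewOp<tensor::PadOp>(
    expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
    padOp.getConstantPaddingValue(), padOp.getNofold());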

Contributor Author:
Done

return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
/// by bubbling the collapse_shape before the pad.
struct FoldReshapeWithProducerPadOpByCollapsing
: public OpRewritePattern<tensor::CollapseShapeOp> {

FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
: OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}

LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
PatternRewriter &rewriter) const override {
tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();

if (!padOp)
return failure();

if (!padOp->hasOneUse())
return failure();

if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(collapseOp,
"fusion blocked by control function");
}

Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue) {
return rewriter.notifyMatchFailure(
collapseOp, "cannot fold with non-constant padding value");
}

SmallVector<ReassociationIndices> reassociations =
collapseOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();

for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() > 1) {
for (auto dimIdx : reInd) {
if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
return failure();
}
}
}
}

SmallVector<OpFoldResult> newLow, newHigh;
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
newLow.push_back(low[reInd[0]]);
newHigh.push_back(high[reInd[0]]);
}

Contributor:
nit: combine this loop with the loop above.
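Roughly like this (a sketch that keeps the existing isZero checks and failure handling):

SmallVector<OpFoldResult> newLow, newHigh;
for (const auto &reInd : reassociations) {
  if (reInd.size() > 1) {
    for (auto dimIdx : reInd) {
      // Collapsed groups may not carry any padding.
      if (!isZero(low[dimIdx]) || !isZero(high[dimIdx]))
        return failure();
    }
  }
  newLow.push_back(low[reInd[0]]);
  newHigh.push_back(high[reInd[0]]);
}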

Location loc = collapseOp.getLoc();
auto resultType = collapseOp.getResultType();

auto finalType = cast<RankedTensorType>(collapseOp.getType());
ArrayRef<int64_t> finalShape = finalType.getShape();

SmallVector<OpFoldResult> collapsedShape;
for (int64_t dimSize : finalShape) {
if (dimSize == ShapedType::kDynamic) {
collapsedShape.push_back(OpFoldResult{});
} else {
collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
}
}
Contributor:
You shouldn't need the mixed sizes for the collapsed shape. It is only used for getting the new type, so you can just collect static sizes instead (SmallVector<int64_t>).
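One way to do that (a sketch; it reuses the patch's isZero helper and indexes finalShape by the output dim, so the later mixed-size fixup loop would go away):

SmallVector<int64_t> staticCollapsedShape;
for (auto [outDimIdx, reInd] : llvm::enumerate(reassociations)) {
  if (reInd.size() == 1 && (!isZero(low[reInd[0]]) || !isZero(high[reInd[0]]))) {
    // A padded dim keeps the producer's pre-pad size (possibly dynamic).
    auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
    staticCollapsedShape.push_back(srcType.getDimSize(reInd[0]));
  } else {
    // Unpadded groups collapse to the same size as in the final result type.
    staticCollapsedShape.push_back(finalShape[outDimIdx]);
  }
}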


for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[reInd[0]];
OpFoldResult h = high[reInd[0]];

if (!isZero(l) || !isZero(h)) {
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
int64_t originalSize = srcType.getDimSize(reInd[0]);

OpFoldResult originalSizeOFR;
if (originalSize == ShapedType::kDynamic) {
Value orgSizeVal =
rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
originalSizeOFR = orgSizeVal;
} else {
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
}
collapsedShape[inDimIdx] = originalSizeOFR;
}
}

SmallVector<int64_t> staticCollapsedShape;
for (OpFoldResult dim : collapsedShape) {
if (auto attr = dyn_cast<Attribute>(dim)) {
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
staticCollapsedShape.push_back(intAttr.getInt());
} else {
staticCollapsedShape.push_back(ShapedType::kDynamic);
}
} else {
staticCollapsedShape.push_back(ShapedType::kDynamic);
}
}

auto newCollapseType = RankedTensorType::get(
staticCollapsedShape, padOp.getSource().getType().getElementType());
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
loc, newCollapseType, padOp.getSource(), reassociations);

auto newPadOp = rewriter.create<tensor::PadOp>(
loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOp(collapseOp, newPadOp.getResult());

return success();
}

private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2498,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2246,6 +2511,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
53 changes: 52 additions & 1 deletion mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<?x?xf32>)
ins(%0 : tensor<?x?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
// CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>

// -----

func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
// CHECK: return %[[PADDED]] : tensor<512x258x258xf32>

// -----

func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
%padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %cst : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
: tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
return %collapsed : tensor<512x258x258xf32>
}
// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
// CHECK: ^bb0(
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: return %[[PADDED]]