37 changes: 34 additions & 3 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
}
};

/// Fold tensor.extract_slice(linalg.fill(<input>)) into linalg.fill(<input>)
Contributor: [nit] Add a note that this triggers only when there's only one use "by design" (with one use this is obviously beneficial; with multiple uses it is not so clear).
struct FoldFillWithTensorExtractSlice
    : public OpRewritePattern<tensor::ExtractSliceOp> {
public:
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
                                PatternRewriter &rewriter) const override {
    // Check whether the source of the tensor.extract_slice op is the result
    // of a linalg.fill op with a single use.
    auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp || !fillOp->hasOneUse())
      return failure();

    Value fillInput = fillOp.value();

    Location loc = extractSliceOp.getLoc();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    auto emptyOp = rewriter.create<tensor::EmptyOp>(
        loc, mixedSizes, extractSliceOp.getType().getElementType());

    // Replace the tensor.extract_slice op with a new linalg.fill op that
    // produces the slice's result type and shape.
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        extractSliceOp, extractSliceOp.getResultType(), ValueRange{fillInput},
        ValueRange{emptyOp});
    return success();
  }
};
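Illustrated on a small example (hypothetical shapes, not taken from the patch), the rewrite replaces a slice of a filled tensor with a direct fill of the smaller tensor:

```mlir
// Before: a 4xf32 slice is extracted from an 8xf32 fill.
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<8xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32>
%slice = tensor.extract_slice %fill[2] [4] [1] : tensor<8xf32> to tensor<4xf32>

// After: the smaller tensor is filled directly; the large fill becomes dead
// (which is why the pattern requires it to have a single use).
%cst = arith.constant 0.0 : f32
%empty2 = tensor.empty() : tensor<4xf32>
%slice = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<4xf32>) -> tensor<4xf32>
```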

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
@@ -935,11 +965,12 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {

 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
-              FoldFillWithPack, FoldFillWithPad,
+  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
+              FoldFillWithPad, FoldFillWithTensorExtract,
+              FoldFillWithTensorExtractSlice,
               FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
               FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
+              FoldFillWithTranspose, FoldInsertPadIntoFill>(context);
 }

//===----------------------------------------------------------------------===//
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {

// -----

func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
  %c0 = arith.constant 0. : f32
  %0 = tensor.empty() : tensor<2x1920x66x66xf32>
  %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
  %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
  return %extracted_slice : tensor<2x1920x64x66xf32>
}
// CHECK-LABEL: func.func @fold_fill_extract_slice
// CHECK: %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[EMPTY_TENSOR]]
// CHECK: return %[[FILL]]

Review thread on @fold_fill_extract_slice:

Contributor: Please add a negative test, e.g. linalg.fill followed by tensor.insert? Provided that doesn't trigger some other canonicalization.

Author: Could you elaborate on why we may need such a test here? To test that such a pattern wouldn't get folded?

Contributor: Indeed. IMHO, we should always try covering "interesting" corner cases (within reason). These tests effectively document the code and its design; a negative test is an example of where the fold should not apply. And if anything changes, we will know almost immediately (the test will start failing).

Review thread on the linalg.fill line:

Contributor: What if this linalg.fill has multiple uses?

Contributor: It doesn't matter. You will have two fills, one larger and one smaller. Fills are always better, I think (especially in tensor semantics).

// -----
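A sketch of the negative test requested in review (hypothetical test name, not part of the patch): the fill below has a second use, so the `hasOneUse()` check should make the pattern bail out and the `tensor.extract_slice` should survive canonicalization.

```mlir
func.func @no_fold_fill_extract_slice_multi_use() -> (tensor<2x1920x66x66xf32>, tensor<2x1920x64x66xf32>) {
  %c0 = arith.constant 0. : f32
  %0 = tensor.empty() : tensor<2x1920x66x66xf32>
  %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
  %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
  // %1 has a second use here (it is returned), so the fold must not fire.
  return %1, %extracted_slice : tensor<2x1920x66x66xf32>, tensor<2x1920x64x66xf32>
}
// CHECK-LABEL: func.func @no_fold_fill_extract_slice_multi_use
// CHECK: tensor.extract_slice
```

// -----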

func.func @fill_pack() -> tensor<24x32x16x16xf32> {
%dest = tensor.empty() : tensor<384x512xf32>
%cst = arith.constant 0.000000e+00 : f32