Skip to content

Commit 758228d

Browse files
committed
[mlir] Canonicalize tensor.extract_slice (linalg.fill)
Signed-off-by: nithinsubbiah <[email protected]>
1 parent 38caf28 commit 758228d

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
806806
}
807807
};
808808

809+
/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
810+
struct FoldFillWithTensorExtractSlice
811+
: public OpRewritePattern<tensor::ExtractSliceOp> {
812+
public:
813+
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
814+
815+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
816+
PatternRewriter &rewriter) const override {
817+
// See if tensor input of tensor.extract_slice op is the result of a
818+
// linalg.fill op.
819+
auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
820+
if (!fillOp)
821+
return failure();
822+
823+
Value fillInput = fillOp.getInputs()[0];
824+
825+
Location loc = extractSliceOp.getLoc();
826+
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
827+
auto emptyOp = rewriter.create<tensor::EmptyOp>(
828+
loc, mixedSizes, extractSliceOp.getType().getElementType());
829+
830+
// Replace tensor.extract_slice op with new linalg.fillOp (former's result
831+
// type and shape).
832+
rewriter.replaceOpWithNewOp<linalg::FillOp>(
833+
extractSliceOp, extractSliceOp.getResultType(), ValueRange{fillInput},
834+
ValueRange{emptyOp});
835+
return success();
836+
}
837+
};
838+
809839
/// Folds pack(fill) into a single fill op if
810840
/// 1. The pack op does not have padding value, or
811841
/// 2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
936966
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
937967
MLIRContext *context) {
938968
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
939-
FoldFillWithPack, FoldFillWithPad,
969+
FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
940970
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
941971
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
942972
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 {
352352

353353
// -----
354354

355+
func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> {
356+
%c0 = arith.constant 0. : f32
357+
%0 = tensor.empty() : tensor<2x1920x66x66xf32>
358+
%1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32>
359+
%extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32>
360+
return %extracted_slice : tensor<2x1920x64x66xf32>
361+
}
362+
// CHECK-LABEL: func.func @fold_fill_extract_slice
363+
// CHECK: %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32>
364+
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[EMPTY_TENSOR]]
365+
// CHECK: return %[[FILL]]
366+
367+
// -----
368+
355369
func.func @fill_pack() -> tensor<24x32x16x16xf32> {
356370
%dest = tensor.empty() : tensor<384x512xf32>
357371
%cst = arith.constant 0.000000e+00 : f32

0 commit comments

Comments
 (0)