diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 730c478c2883e..2f2b6fed2add4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern { } }; +/// Fold tensor.extract_slice(linalg.fill()) into linalg.fill() +struct FoldFillWithTensorExtractSlice + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp, + PatternRewriter &rewriter) const override { + // See if tensor input of tensor.extract_slice op is the result of a + // linalg.fill op. + auto fillOp = extractSliceOp.getSource().getDefiningOp(); + if (!fillOp || !fillOp->hasOneUse()) + return failure(); + + Value fillInput = fillOp.value(); + + Location loc = extractSliceOp.getLoc(); + SmallVector mixedSizes = extractSliceOp.getMixedSizes(); + auto emptyOp = rewriter.create( + loc, mixedSizes, extractSliceOp.getType().getElementType()); + + // Replace tensor.extract_slice op with new linalg.fillOp (former's result + // type and shape). + rewriter.replaceOpWithNewOp( + extractSliceOp, extractSliceOp.getResultType(), ValueRange{fillInput}, + ValueRange{emptyOp}); + return success(); + } +}; + /// Folds pack(fill) into a single fill op if /// 1. The pack op does not have padding value, or /// 2. The filled value and padding value are the same. @@ -935,11 +965,12 @@ struct FoldConcatsOfFill : public OpRewritePattern { void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, FoldFillWithTensorReshape, - FoldInsertPadIntoFill, FoldFillWithTranspose>(context); + FoldFillWithTranspose, FoldInsertPadIntoFill>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 4bc2ed140da91..763cf80241fcc 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -352,6 +352,20 @@ func.func @fold_fill_extract(%arg0 : i1) -> i1 { // ----- +func.func @fold_fill_extract_slice() -> tensor<2x1920x64x66xf32> { + %c0 = arith.constant 0. : f32 + %0 = tensor.empty() : tensor<2x1920x66x66xf32> + %1 = linalg.fill ins(%c0 : f32) outs(%0 : tensor<2x1920x66x66xf32>) -> tensor<2x1920x66x66xf32> + %extracted_slice = tensor.extract_slice %1[0, 0, 1, 0] [2, 1920, 64, 66] [1, 1, 1, 1] : tensor<2x1920x66x66xf32> to tensor<2x1920x64x66xf32> + return %extracted_slice : tensor<2x1920x64x66xf32> +} +// CHECK-LABEL: func.func @fold_fill_extract_slice +// CHECK: %[[EMPTY_TENSOR:.+]] = tensor.empty() : tensor<2x1920x64x66xf32> +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[EMPTY_TENSOR]] +// CHECK: return %[[FILL]] + +// ----- + func.func @fill_pack() -> tensor<24x32x16x16xf32> { %dest = tensor.empty() : tensor<384x512xf32> %cst = arith.constant 0.000000e+00 : f32