@@ -806,6 +806,36 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
806806 }
807807};
808808
809+ // / Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
810+ struct FoldFillWithTensorExtractSlice
811+ : public OpRewritePattern<tensor::ExtractSliceOp> {
812+ public:
813+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
814+
815+ LogicalResult matchAndRewrite (tensor::ExtractSliceOp extractSliceOp,
816+ PatternRewriter &rewriter) const override {
817+ // See if tensor input of tensor.extract_slice op is the result of a
818+ // linalg.fill op.
819+ auto fillOp = extractSliceOp.getSource ().getDefiningOp <linalg::FillOp>();
820+ if (!fillOp)
821+ return failure ();
822+
823+ Value fillInput = fillOp.getInputs ()[0 ];
824+
825+ Location loc = extractSliceOp.getLoc ();
826+ SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes ();
827+ auto emptyOp = rewriter.create <tensor::EmptyOp>(
828+ loc, mixedSizes, extractSliceOp.getType ().getElementType ());
829+
830+ // Replace tensor.extract_slice op with new linalg.fillOp (former's result
831+ // type and shape).
832+ rewriter.replaceOpWithNewOp <linalg::FillOp>(
833+ extractSliceOp, extractSliceOp.getResultType (), ValueRange{fillInput},
834+ ValueRange{emptyOp});
835+ return success ();
836+ }
837+ };
838+
809839// / Folds pack(fill) into a single fill op if
810840// / 1. The pack op does not have padding value, or
811841// / 2. The filled value and padding value are the same.
@@ -936,7 +966,7 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
936966void FillOp::getCanonicalizationPatterns (RewritePatternSet &results,
937967 MLIRContext *context) {
938968 results.add <FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
939- FoldFillWithPack, FoldFillWithPad,
969+ FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
940970 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
941971 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
942972 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
0 commit comments