@@ -806,7 +806,7 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
806806 }
807807};
808808
809- // / Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
809+ // / Fold tensor.extract_slice(linalg.fill(<input>)) into linalg.fill( <input>)
810810struct FoldFillWithTensorExtractSlice
811811 : public OpRewritePattern<tensor::ExtractSliceOp> {
812812public:
@@ -817,10 +817,10 @@ struct FoldFillWithTensorExtractSlice
817817 // See if tensor input of tensor.extract_slice op is the result of a
818818 // linalg.fill op.
819819 auto fillOp = extractSliceOp.getSource ().getDefiningOp <linalg::FillOp>();
820- if (!fillOp)
820+ if (!fillOp || !fillOp-> hasOneUse () )
821821 return failure ();
822822
823- Value fillInput = fillOp.getInputs ()[ 0 ] ;
823+ Value fillInput = fillOp.value () ;
824824
825825 Location loc = extractSliceOp.getLoc ();
826826 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes ();
@@ -965,11 +965,12 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
965965
966966void FillOp::getCanonicalizationPatterns (RewritePatternSet &results,
967967 MLIRContext *context) {
968- results.add <FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
969- FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
968+ results.add <FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
969+ FoldFillWithPad, FoldFillWithTensorExtract,
970+ FoldFillWithTensorExtractSlice,
970971 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
971972 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
972- FoldInsertPadIntoFill, FoldFillWithTranspose >(context);
973+ FoldFillWithTranspose, FoldInsertPadIntoFill >(context);
973974}
974975
975976// ===----------------------------------------------------------------------===//
0 commit comments