Skip to content

Commit 52e7b53

Browse files
committed
[mlir] Check if ExtractSliceOp is the only consumer while folding into FillOp
Signed-off-by: nithinsubbiah <[email protected]>
1 parent 758228d commit 52e7b53

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
806806
}
807807
};
808808

809-
/// Fold tensor.extract_slice(linalg.fill(<input>)) into <input>
809+
/// Fold tensor.extract_slice(linalg.fill(<input>)) into linalg.fill(<input>)
810810
struct FoldFillWithTensorExtractSlice
811811
: public OpRewritePattern<tensor::ExtractSliceOp> {
812812
public:
@@ -817,10 +817,10 @@ struct FoldFillWithTensorExtractSlice
817817
// See if tensor input of tensor.extract_slice op is the result of a
818818
// linalg.fill op.
819819
auto fillOp = extractSliceOp.getSource().getDefiningOp<linalg::FillOp>();
820-
if (!fillOp)
820+
if (!fillOp || !fillOp->hasOneUse())
821821
return failure();
822822

823-
Value fillInput = fillOp.getInputs()[0];
823+
Value fillInput = fillOp.value();
824824

825825
Location loc = extractSliceOp.getLoc();
826826
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
@@ -965,11 +965,12 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
965965

966966
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
967967
MLIRContext *context) {
968-
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
969-
FoldFillWithTensorExtractSlice, FoldFillWithPack, FoldFillWithPad,
968+
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithPack,
969+
FoldFillWithPad, FoldFillWithTensorExtract,
970+
FoldFillWithTensorExtractSlice,
970971
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
971972
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
972-
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
973+
FoldFillWithTranspose, FoldInsertPadIntoFill>(context);
973974
}
974975

975976
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)