1919#include " mlir/Dialect/Tensor/IR/Tensor.h"
2020#include " mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
2121#include " mlir/Dialect/Utils/StaticValueUtils.h"
22+ #include " mlir/IR/BuiltinTypeInterfaces.h"
2223#include " mlir/IR/Dialect.h"
2324#include " mlir/IR/Operation.h"
2425
@@ -636,6 +637,34 @@ struct InsertOpInterface
636637 }
637638};
638639
640+ template <typename InsertOpTy>
641+ static bool insertSliceOpRequiresRead (InsertOpTy insertSliceOp,
642+ OpOperand &opOperand) {
643+ RankedTensorType destType = insertSliceOp.getDestType ();
644+
645+ // The source is always read.
646+ if (opOperand == insertSliceOp.getSourceMutable ())
647+ return true ;
648+
649+ // For the destination, it depends...
650+ assert (opOperand == insertSliceOp.getDestMutable () && " expected dest" );
651+
652+ // Dest is not read if it is entirely overwritten. E.g.:
653+ // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
654+ bool allOffsetsZero =
655+ llvm::all_of (insertSliceOp.getMixedOffsets (),
656+ [](OpFoldResult ofr) { return isConstantIntValue (ofr, 0 ); });
657+ bool sizesMatchDestSizes = llvm::all_of (
658+ llvm::enumerate (insertSliceOp.getMixedSizes ()), [&](const auto &it) {
659+ return getConstantIntValue (it.value ()) ==
660+ destType.getDimSize (it.index ());
661+ });
662+ bool allStridesOne =
663+ llvm::all_of (insertSliceOp.getMixedStrides (),
664+ [](OpFoldResult ofr) { return isConstantIntValue (ofr, 1 ); });
665+ return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
666+ }
667+
639668// / Bufferization of tensor.insert_slice. Replace with a memory copy. Under
640669// / certain circumstances, this op can also be a no-op.
641670// /
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
646675 tensor::InsertSliceOp> {
647676 bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
648677 const AnalysisState &state) const {
649- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
650- RankedTensorType destType = insertSliceOp.getDestType ();
651-
652- // The source is always read.
653- if (opOperand == insertSliceOp.getSourceMutable ())
654- return true ;
655-
656- // For the destination, it depends...
657- assert (opOperand == insertSliceOp.getDestMutable () && " expected dest" );
658-
659- // Dest is not read if it is entirely overwritten. E.g.:
660- // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
661- bool allOffsetsZero =
662- llvm::all_of (insertSliceOp.getMixedOffsets (), [](OpFoldResult ofr) {
663- return isConstantIntValue (ofr, 0 );
664- });
665- bool sizesMatchDestSizes = llvm::all_of (
666- llvm::enumerate (insertSliceOp.getMixedSizes ()), [&](const auto &it) {
667- return getConstantIntValue (it.value ()) ==
668- destType.getDimSize (it.index ());
669- });
670- bool allStridesOne =
671- llvm::all_of (insertSliceOp.getMixedStrides (), [](OpFoldResult ofr) {
672- return isConstantIntValue (ofr, 1 );
673- });
674- return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
678+ return insertSliceOpRequiresRead (cast<tensor::InsertSliceOp>(op),
679+ opOperand);
675680 }
676681
677682 LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
931936
932937 bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
933938 const AnalysisState &state) const {
934- return true ;
939+ return insertSliceOpRequiresRead (cast<tensor::ParallelInsertSliceOp>(op),
940+ opOperand);
935941 }
936942
937943 bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
0 commit comments