2727#include " mlir/Interfaces/InferIntRangeInterface.h"
2828#include " mlir/Interfaces/LoopLikeInterface.h"
2929#include " mlir/Interfaces/Utils/InferIntRangeCommon.h"
30+ #include " mlir/Interfaces/ViewLikeInterface.h"
3031#include " mlir/Support/LLVM.h"
3132#include " llvm/ADT/DenseSet.h"
3233#include " llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23522353 }
23532354}
23542355
2355- // / Verify that the offsets/sizes/strides-style access into the given tensor
2356- // / is in-bounds. Only static information is verified.
2357- static LogicalResult verifyInBoundsSlice (Operation *op,
2358- RankedTensorType tensorType,
2359- ArrayRef<int64_t > staticOffsets,
2360- ArrayRef<int64_t > staticSizes,
2361- ArrayRef<int64_t > staticStrides) {
2362- for (int64_t i = 0 , e = tensorType.getRank (); i < e; ++i) {
2363- // Nothing to verify for dynamic source dims.
2364- if (tensorType.isDynamicDim (i))
2365- continue ;
2366- // Nothing to verify if the offset is dynamic.
2367- if (ShapedType::isDynamic (staticOffsets[i]))
2368- continue ;
2369- if (staticOffsets[i] >= tensorType.getDimSize (i))
2370- return op->emitOpError (" offset " )
2371- << i << " is out-of-bounds: " << staticOffsets[i]
2372- << " >= " << tensorType.getDimSize (i);
2373- if (ShapedType::isDynamic (staticSizes[i]) ||
2374- ShapedType::isDynamic (staticStrides[i]))
2375- continue ;
2376- int64_t lastPos =
2377- staticOffsets[i] + (staticSizes[i] - 1 ) * staticStrides[i];
2378- if (lastPos >= tensorType.getDimSize (i))
2379- return op->emitOpError (" slice along dimension " )
2380- << i << " runs out-of-bounds: " << lastPos
2381- << " >= " << tensorType.getDimSize (i);
2382- }
2383- return success ();
2384- }
2385-
23862356// / Verifier for ExtractSliceOp.
23872357LogicalResult ExtractSliceOp::verify () {
23882358 RankedTensorType sourceType = getSourceType ();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
23962366
23972367 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
23982368 // to the source tensor.
2399- return verifyInBoundsSlice (getOperation (), sourceType, getStaticOffsets (),
2400- getStaticSizes (), getStaticStrides ());
2369+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
2370+ sourceType.getShape (), getStaticOffsets (), getStaticSizes (),
2371+ getStaticStrides (), /* generateErrorMessage=*/ true );
2372+ if (!boundsResult.isValid )
2373+ return getOperation ()->emitError (boundsResult.errorMessage );
2374+
2375+ return success ();
24012376}
24022377
24032378llvm::SmallBitVector ExtractSliceOp::getDroppedDims () {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
24702445 if (!canFoldIntoConsumerOp (castOp))
24712446 return failure ();
24722447
2448+ // Pattern does not apply if the produced op would not verify.
2449+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice (
2450+ cast<RankedTensorType>(castOp.getSource ().getType ()).getShape (),
2451+ sliceOp.getStaticOffsets (), sliceOp.getStaticSizes (),
2452+ sliceOp.getStaticStrides ());
2453+ if (!sliceResult.isValid )
2454+ return failure ();
2455+
24732456 // Create folded extract.
24742457 Location loc = sliceOp.getLoc ();
24752458 Value newResult = rewriter.create <ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
26342617
26352618void ExtractSliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
26362619 MLIRContext *context) {
2637- results.add <
2638- OpWithOffsetSizesAndStridesConstantArgumentFolder<
2639- ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2640- ExtractSliceOpCastFolder>(context);
2620+ results.add <OpWithOffsetSizesAndStridesConstantArgumentFolder<
2621+ ExtractSliceOp, SliceReturnTypeCanonicalizer,
2622+ SliceCanonicalizer, /* CheckInBounds= */ true >,
2623+ ExtractSliceOpCastFolder>(context);
26412624}
26422625
26432626//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
27752758 return produceSliceErrorMsg (result, *this , expectedType);
27762759
27772760 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2778- // to the source tensor.
2779- return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
2780- getStaticSizes (), getStaticStrides ());
2761+ // to the destination tensor.
2762+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
2763+ getDestType ().getShape (), getStaticOffsets (), getStaticSizes (),
2764+ getStaticStrides (), /* generateErrorMessage=*/ true );
2765+ if (!boundsResult.isValid )
2766+ return getOperation ()->emitError (boundsResult.errorMessage );
2767+
2768+ return success ();
27812769}
27822770
27832771// / If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
28722860 failed (foldDynamicStrideList (mixedStrides)))
28732861 return failure ();
28742862
2863+ // Pattern does not apply if the produced op would not verify.
2864+ SliceBoundsVerificationResult sliceResult =
2865+ verifyInBoundsSlice (insertSliceOp.getDest ().getType ().getShape (),
2866+ mixedOffsets, mixedSizes, mixedStrides);
2867+ if (!sliceResult.isValid )
2868+ return failure ();
2869+
28752870 // Create the new op in canonical form.
28762871 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType (
28772872 insertSliceOp.getSourceType ().getRank (), insertSliceOp.getDestType (),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
29692964 size = srcType.getDimSize (rankReducedIdx++);
29702965 }
29712966 }
2967+
2968+ // Pattern does not apply if the produced op would not verify.
29722969 if (verifyInsertSliceOp (srcType, dstType, insertSliceOp.getStaticOffsets (),
29732970 staticSizes, insertSliceOp.getStaticStrides ()) !=
29742971 SliceVerificationResult::Success)
29752972 return failure ();
2973+ SliceBoundsVerificationResult sliceResult =
2974+ verifyInBoundsSlice (dstType.getShape (), insertSliceOp.getMixedOffsets (),
2975+ mixedSizes, insertSliceOp.getMixedStrides ());
2976+ if (!sliceResult.isValid )
2977+ return failure ();
29762978
29772979 Operation *replacement = rewriter.create <InsertOpTy>(
29782980 insertSliceOp.getLoc (), src, dst, insertSliceOp.getMixedOffsets (),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
38003802 return produceSliceErrorMsg (result, *this , expectedType);
38013803
38023804 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3803- // to the source tensor.
3804- return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
3805- getStaticSizes (), getStaticStrides ());
3805+ // to the destination tensor.
3806+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice (
3807+ getDestType ().getShape (), getStaticOffsets (), getStaticSizes (),
3808+ getStaticStrides (), /* generateErrorMessage=*/ true );
3809+ if (!boundsResult.isValid )
3810+ return getOperation ()->emitError (boundsResult.errorMessage );
3811+
3812+ return success ();
38063813}
38073814
38083815void ParallelInsertSliceOp::getCanonicalizationPatterns (
0 commit comments