@@ -2352,13 +2352,52 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
23522352 }
23532353}
23542354
2355+ // / Verify that the offsets/sizes/strides-style access into the given tensor
2356+ // / is in-bounds. Only static information is verified.
2357+ static LogicalResult verifyInBoundsSlice (Operation *op,
2358+ RankedTensorType tensorType,
2359+ ArrayRef<int64_t > staticOffsets,
2360+ ArrayRef<int64_t > staticSizes,
2361+ ArrayRef<int64_t > staticStrides) {
2362+ for (int64_t i = 0 , e = tensorType.getRank (); i < e; ++i) {
2363+ // Nothing to verify for dynamic source dims.
2364+ if (tensorType.isDynamicDim (i))
2365+ continue ;
2366+ // Nothing to verify if the offset is dynamic.
2367+ if (ShapedType::isDynamic (staticOffsets[i]))
2368+ continue ;
2369+ if (staticOffsets[i] >= tensorType.getDimSize (i))
2370+ return op->emitOpError (" offset " )
2371+ << i << " is out-of-bounds: " << staticOffsets[i]
2372+ << " >= " << tensorType.getDimSize (i);
2373+ if (ShapedType::isDynamic (staticSizes[i]) ||
2374+ ShapedType::isDynamic (staticStrides[i]))
2375+ continue ;
2376+ int64_t lastPos =
2377+ staticOffsets[i] + (staticSizes[i] - 1 ) * staticStrides[i];
2378+ if (lastPos >= tensorType.getDimSize (i))
2379+ return op->emitOpError (" slice along dimension " )
2380+ << i << " runs out-of-bounds: " << lastPos
2381+ << " >= " << tensorType.getDimSize (i);
2382+ }
2383+ return success ();
2384+ }
2385+
23552386// / Verifier for ExtractSliceOp.
23562387LogicalResult ExtractSliceOp::verify () {
2388+ RankedTensorType sourceType = getSourceType ();
2389+
23572390 // Verify result type against inferred type.
23582391 RankedTensorType expectedType = ExtractSliceOp::inferResultType (
2359- getSourceType () , getMixedOffsets (), getMixedSizes (), getMixedStrides ());
2392+ sourceType , getMixedOffsets (), getMixedSizes (), getMixedStrides ());
23602393 SliceVerificationResult result = isRankReducedType (expectedType, getType ());
2361- return produceSliceErrorMsg (result, *this , expectedType);
2394+ if (result != SliceVerificationResult::Success)
2395+ return produceSliceErrorMsg (result, *this , expectedType);
2396+
2397+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2398+ // to the source tensor.
2399+ return verifyInBoundsSlice (getOperation (), sourceType, getStaticOffsets (),
2400+ getStaticSizes (), getStaticStrides ());
23622401}
23632402
23642403llvm::SmallBitVector ExtractSliceOp::getDroppedDims () {
@@ -2729,11 +2768,18 @@ static SliceVerificationResult verifyInsertSliceOp(
27292768
27302769// / Verifier for InsertSliceOp.
27312770LogicalResult InsertSliceOp::verify () {
2771+ // Verify result type against inferred type.
27322772 RankedTensorType expectedType;
27332773 SliceVerificationResult result =
27342774 verifyInsertSliceOp (getSourceType (), getType (), getStaticOffsets (),
27352775 getStaticSizes (), getStaticStrides (), &expectedType);
2736- return produceSliceErrorMsg (result, *this , expectedType);
2776+ if (result != SliceVerificationResult::Success)
2777+ return produceSliceErrorMsg (result, *this , expectedType);
2778+
2779+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
2780+ // to the source tensor.
2781+ return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
2782+ getStaticSizes (), getStaticStrides ());
27372783}
27382784
27392785// / If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -3747,11 +3793,18 @@ LogicalResult ParallelInsertSliceOp::verify() {
37473793 return this ->emitError (" expected ParallelCombiningOpInterface parent, got:" )
37483794 << *(getOperation ()->getParentOp ());
37493795
3796+ // Verify result type against inferred type.
37503797 RankedTensorType expectedType;
37513798 SliceVerificationResult result =
37523799 verifyInsertSliceOp (getSourceType (), getDestType (), getStaticOffsets (),
37533800 getStaticSizes (), getStaticStrides (), &expectedType);
3754- return produceSliceErrorMsg (result, *this , expectedType);
3801+ if (result != SliceVerificationResult::Success)
3802+ return produceSliceErrorMsg (result, *this , expectedType);
3803+
3804+ // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3805+ // to the source tensor.
3806+ return verifyInBoundsSlice (getOperation (), getDestType (), getStaticOffsets (),
3807+ getStaticSizes (), getStaticStrides ());
37553808}
37563809
37573810void ParallelInsertSliceOp::getCanonicalizationPatterns (
0 commit comments