From 28fa4ffa7d91c776dccf8145bf1867a9db54953f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 22 Mar 2025 09:52:09 +0100 Subject: [PATCH] [mlir][tensor] Fix slice canonicalizer for out-of-bounds cases --- .../mlir/Interfaces/ViewLikeInterface.h | 38 +++++++- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 93 ++++++++++--------- mlir/lib/Interfaces/ViewLikeInterface.cpp | 58 ++++++++++++ mlir/test/Dialect/Tensor/canonicalize.mlir | 50 ++++++++++ 4 files changed, 194 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 8f07e43f847ae..e74326dba7c80 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef staticVals, namespace mlir { +/// Result for slice bounds verification; +struct SliceBoundsVerificationResult { + /// If set to "true", the slice bounds verification was successful. + bool isValid; + /// An error message that can be printed during op verification. + std::string errorMessage; +}; + +/// Verify that the offsets/sizes/strides-style access into the given shape +/// is in-bounds. Only static values are verified. If `generateErrorMessage` +/// is set to "true", an error message is produced that can be printed by the +/// op verifier. +SliceBoundsVerificationResult +verifyInBoundsSlice(ArrayRef shape, ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides, + bool generateErrorMessage = false); +SliceBoundsVerificationResult verifyInBoundsSlice( + ArrayRef shape, ArrayRef mixedOffsets, + ArrayRef mixedSizes, ArrayRef mixedStrides, + bool generateErrorMessage = false); + /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as /// constant arguments. This pattern assumes that the op has a suitable builder /// that takes a result type, a "source" operand and mixed offsets, sizes and @@ -54,7 +76,8 @@ namespace mlir { /// returns the new result type of the op, based on the new offsets, sizes and /// strides. `CastOpFunc` is used to generate a cast op if the result type of /// the op has changed. -template +template class OpWithOffsetSizesAndStridesConstantArgumentFolder final : public OpRewritePattern { public: @@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final failed(foldDynamicIndexList(mixedStrides))) return failure(); - // Create the new op in canonical form. + if (CheckInBounds) { + // Pattern does not apply if the produced op would not verify. + SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice( + cast(op.getSource().getType()).getShape(), mixedOffsets, + mixedSizes, mixedStrides); + if (!sliceResult.isValid) + return failure(); + } + + // Compute the new result type. auto resultType = ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides); if (!resultType) return failure(); + + // Create the new op in canonical form. auto newOp = rewriter.create(op.getLoc(), resultType, op.getSource(), mixedOffsets, mixedSizes, mixedStrides); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 2d5df07f8af4b..5f8493de991f3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } } -/// Verify that the offsets/sizes/strides-style access into the given tensor -/// is in-bounds. Only static information is verified. -static LogicalResult verifyInBoundsSlice(Operation *op, - RankedTensorType tensorType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides) { - for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) { - // Nothing to verify for dynamic source dims. - if (tensorType.isDynamicDim(i)) - continue; - // Nothing to verify if the offset is dynamic. - if (ShapedType::isDynamic(staticOffsets[i])) - continue; - if (staticOffsets[i] >= tensorType.getDimSize(i)) - return op->emitOpError("offset ") - << i << " is out-of-bounds: " << staticOffsets[i] - << " >= " << tensorType.getDimSize(i); - if (ShapedType::isDynamic(staticSizes[i]) || - ShapedType::isDynamic(staticStrides[i])) - continue; - int64_t lastPos = - staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i]; - if (lastPos >= tensorType.getDimSize(i)) - return op->emitOpError("slice along dimension ") - << i << " runs out-of-bounds: " << lastPos - << " >= " << tensorType.getDimSize(i); - } - return success(); -} - /// Verifier for ExtractSliceOp. LogicalResult ExtractSliceOp::verify() { RankedTensorType sourceType = getSourceType(); @@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() { // Verify that offsets, sizes, strides do not run out-of-bounds with respect // to the source tensor. - return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(), - getStaticSizes(), getStaticStrides()); + SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice( + sourceType.getShape(), getStaticOffsets(), getStaticSizes(), + getStaticStrides(), /*generateErrorMessage=*/true); + if (!boundsResult.isValid) + return getOperation()->emitError(boundsResult.errorMessage); + + return success(); } llvm::SmallBitVector ExtractSliceOp::getDroppedDims() { @@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern { if (!canFoldIntoConsumerOp(castOp)) return failure(); + // Pattern does not apply if the produced op would not verify. + SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice( + cast(castOp.getSource().getType()).getShape(), + sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), + sliceOp.getStaticStrides()); + if (!sliceResult.isValid) + return failure(); + // Create folded extract. Location loc = sliceOp.getLoc(); Value newResult = rewriter.create( @@ -2634,10 +2617,10 @@ struct SliceCanonicalizer { void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add< - OpWithOffsetSizesAndStridesConstantArgumentFolder< - ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>, - ExtractSliceOpCastFolder>(context); + results.add, + ExtractSliceOpCastFolder>(context); } // @@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() { return produceSliceErrorMsg(result, *this, expectedType); // Verify that offsets, sizes, strides do not run out-of-bounds with respect - // to the source tensor. - return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(), - getStaticSizes(), getStaticStrides()); + // to the destination tensor. + SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice( + getDestType().getShape(), getStaticOffsets(), getStaticSizes(), + getStaticStrides(), /*generateErrorMessage=*/true); + if (!boundsResult.isValid) + return getOperation()->emitError(boundsResult.errorMessage); + + return success(); } /// If we have two consecutive InsertSliceOp writing to the same slice, we @@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final failed(foldDynamicStrideList(mixedStrides))) return failure(); + // Pattern does not apply if the produced op would not verify. + SliceBoundsVerificationResult sliceResult = + verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(), + mixedOffsets, mixedSizes, mixedStrides); + if (!sliceResult.isValid) + return failure(); + // Create the new op in canonical form. auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType( insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(), @@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern { size = srcType.getDimSize(rankReducedIdx++); } } + + // Pattern does not apply if the produced op would not verify. if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(), staticSizes, insertSliceOp.getStaticStrides()) != SliceVerificationResult::Success) return failure(); + SliceBoundsVerificationResult sliceResult = + verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(), + mixedSizes, insertSliceOp.getMixedStrides()); + if (!sliceResult.isValid) + return failure(); Operation *replacement = rewriter.create( insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(), @@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() { return produceSliceErrorMsg(result, *this, expectedType); // Verify that offsets, sizes, strides do not run out-of-bounds with respect - // to the source tensor. - return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(), - getStaticSizes(), getStaticStrides()); + // to the destination tensor. + SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice( + getDestType().getShape(), getStaticOffsets(), getStaticSizes(), + getStaticStrides(), /*generateErrorMessage=*/true); + if (!boundsResult.isValid) + return getOperation()->emitError(boundsResult.errorMessage); + + return success(); } void ParallelInsertSliceOp::getCanonicalizationPatterns( diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index 57b5cce7bb13b..70dd7b4aec88c 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, return success(); } +SliceBoundsVerificationResult mlir::verifyInBoundsSlice( + ArrayRef shape, ArrayRef staticOffsets, + ArrayRef staticSizes, ArrayRef staticStrides, + bool generateErrorMessage) { + SliceBoundsVerificationResult result; + result.isValid = true; + for (int64_t i = 0, e = shape.size(); i < e; ++i) { + // Nothing to verify for dynamic source dims. + if (ShapedType::isDynamic(shape[i])) + continue; + // Nothing to verify if the offset is dynamic. + if (ShapedType::isDynamic(staticOffsets[i])) + continue; + if (staticOffsets[i] >= shape[i]) { + result.errorMessage = + std::string("offset ") + std::to_string(i) + + " is out-of-bounds: " + std::to_string(staticOffsets[i]) + + " >= " + std::to_string(shape[i]); + result.isValid = false; + return result; + } + if (ShapedType::isDynamic(staticSizes[i]) || + ShapedType::isDynamic(staticStrides[i])) + continue; + int64_t lastPos = + staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i]; + if (lastPos >= shape[i]) { + result.errorMessage = std::string("slice along dimension ") + + std::to_string(i) + + " runs out-of-bounds: " + std::to_string(lastPos) + + " >= " + std::to_string(shape[i]); + result.isValid = false; + return result; + } + } + return result; +} + +SliceBoundsVerificationResult mlir::verifyInBoundsSlice( + ArrayRef shape, ArrayRef mixedOffsets, + ArrayRef mixedSizes, ArrayRef mixedStrides, + bool generateErrorMessage) { + auto getStaticValues = [](ArrayRef ofrs) { + SmallVector staticValues; + for (OpFoldResult ofr : ofrs) { + if (auto attr = dyn_cast(ofr)) { + staticValues.push_back(cast(attr).getInt()); + } else { + staticValues.push_back(ShapedType::kDynamic); + } + } + return staticValues; + }; + return verifyInBoundsSlice( + shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes), + getStaticValues(mixedStrides), generateErrorMessage); +} + LogicalResult mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { std::array maxRanks = op.getArrayAttrMaxRanks(); diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 90cc0ca658ffb..fd96328c6033d 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1 // ----- +// CHECK-LABEL: func @out_of_bounds_extract_slice +// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor +func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor { + %c10 = arith.constant 10 : index + %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor + return %r : tensor +} + +// ----- + +// CHECK-LABEL: func @out_of_bounds_extract_slice +// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor to tensor<10xf32> +func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> { + %t2 = tensor.cast %t : tensor<5xf32> to tensor + %r = tensor.extract_slice %t2 [0][10][1] : tensor to tensor<10xf32> + return %r : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: func @out_of_bounds_insert_slice +// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32> +func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> { + %c10 = arith.constant 10 : index + %r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32> + return %r : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: func @out_of_bounds_insert_slice +// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor into tensor<10xf32> +func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> { + %src2 = tensor.cast %src : tensor<5xf32> to tensor + %r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor into tensor<10xf32> + return %r : tensor<10xf32> +} + +// ----- + +// CHECK-LABEL: func @out_of_bounds_insert_slice +// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor +func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor { + %dst2 = tensor.cast %dst : tensor<10xf32> to tensor + %r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor + return %r : tensor +} + +// ----- + // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast // CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8> // CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>