From 2489f75de1b7665e1fdc4c42acb84b3680c9fa9b Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Tue, 11 Feb 2025 22:04:10 +0800 Subject: [PATCH] [mlir][vector] Fix out-of-bounds access This PR fixes an out-of-bounds bug that occurs when there are no overlap dimensions between the `sizes` and source of `vector.extract_strided_slice`, causing access to `sizes` to go out of bounds. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++-- mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 30ff2df7c38fc..f3f20b46add78 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3968,8 +3968,8 @@ class ContiguousExtractStridedSliceToExtract final // Avoid generating slices that have leading unit dimensions. The shape_cast // op that we create below would take bad generic fallback patterns // (ShapeCastOpRewritePattern). - while (sizes[numOffsets] == 1 && - numOffsets < static_cast(sizes.size()) - 1) { + while (numOffsets < static_cast(sizes.size()) - 1 && + sizes[numOffsets] == 1) { ++numOffsets; } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 61e858f5f226a..90e00d6ea9bdc 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2921,10 +2921,10 @@ func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>) // ----- -// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract +// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_sizes_and_outer_source_dims_overlap // CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32> // CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32> -func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> { +func.func @contiguous_extract_strided_slices_to_extract_sizes_and_outer_source_dims_overlap(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> { %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32> %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32> return %2 : vector<4xi32> @@ -2932,6 +2932,17 @@ func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1 // ----- +// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_sizes_and_outer_source_dims_no_overlap +// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0] : vector<4xi32> from vector<8x2x4xi32> +// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32> +func.func @contiguous_extract_strided_slices_to_extract_sizes_and_outer_source_dims_no_overlap(%arg0 : vector<8x2x4xi32>) -> vector<4xi32> { + %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<8x2x4xi32> to vector<1x1x4xi32> + %2 = vector.shape_cast %1 : vector<1x1x4xi32> to vector<4xi32> + return %2 : vector<4xi32> +} + +// ----- + // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list // CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32> // CHECK-NEXT: return %[[EXTRACT]] : vector<1x4xi32>