Skip to content

Commit 7067ff2

Browse files
committed
[mlir][vector] Fix out-of-bounds access
This PR fixes an out-of-bounds bug that occurs when there are no unit dimensions in the source of `vector.extract_strided_slice`, causing access to `sizes` to go out of bounds.
1 parent ac158aa commit 7067ff2

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3968,8 +3968,8 @@ class ContiguousExtractStridedSliceToExtract final
39683968
// Avoid generating slices that have leading unit dimensions. The shape_cast
39693969
// op that we create below would take bad generic fallback patterns
39703970
// (ShapeCastOpRewritePattern).
3971-
while (sizes[numOffsets] == 1 &&
3972-
numOffsets < static_cast<int>(sizes.size()) - 1) {
3971+
while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
3972+
sizes[numOffsets] == 1) {
39733973
++numOffsets;
39743974
}
39753975

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2932,6 +2932,17 @@ func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1
29322932

29332933
// -----
29342934

2935+
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_no_unit_dims
2936+
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0] : vector<4xi32> from vector<8x2x4xi32>
2937+
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
2938+
func.func @contiguous_extract_strided_slices_to_extract_no_unit_dims(%arg0 : vector<8x2x4xi32>) -> vector<4xi32> {
2939+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<8x2x4xi32> to vector<1x1x4xi32>
2940+
%2 = vector.shape_cast %1 : vector<1x1x4xi32> to vector<4xi32>
2941+
return %2 : vector<4xi32>
2942+
}
2943+
2944+
// -----
2945+
29352946
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
29362947
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
29372948
// CHECK-NEXT: return %[[EXTRACT]] : vector<1x4xi32>

0 commit comments

Comments
 (0)