Skip to content

Commit e0edc95

Browse files
ritchie46 authored and alamb committed
ARROW-12398: [Rust] remove redundant bound check in iterators
This PR removes the redundant bounds checks in the array iterators, as discussed in #9994. Furthermore, it adds `unsafe` versions of the `value` method (`value_unchecked`) to `PrimitiveArray` and `BooleanArray`. The existing `value` methods are marked as safe but are actually unsafe, because they perform no bounds checking. This way we can slowly transition to explicitly using the `unsafe` variant and later make the "safe" one truly safe. For the time being, a `debug_assert` bounds check is also added to those "safe" methods that are really unsafe, so that an out-of-bounds index at least panics in debug mode instead of causing UB in safe code. Closes #10046 from ritchie46/iterator_bounds. Authored-by: Ritchie Vink <[email protected]> Signed-off-by: Andrew Lamb <[email protected]>
1 parent d05fc30 commit e0edc95

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

rust/arrow/src/array/array_boolean.rs

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,12 +67,21 @@ impl BooleanArray {
6767
&self.data.buffers()[0]
6868
}
6969

70+
/// Returns the boolean value at index `i`.
71+
///
72+
/// # Safety
73+
/// This doesn't check bounds, the caller must ensure that index < self.len()
74+
pub unsafe fn value_unchecked(&self, i: usize) -> bool {
75+
let offset = i + self.offset();
76+
bit_util::get_bit_raw(self.raw_values.as_ptr(), offset)
77+
}
78+
7079
/// Returns the boolean value at index `i`.
7180
///
7281
/// Note this doesn't do any bound checking, for performance reason.
7382
pub fn value(&self, i: usize) -> bool {
74-
let offset = i + self.offset();
75-
unsafe { bit_util::get_bit_raw(self.raw_values.as_ptr(), offset) }
83+
debug_assert!(i < self.len());
84+
unsafe { self.value_unchecked(i) }
7685
}
7786
}
7887

rust/arrow/src/array/array_primitive.rs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -88,14 +88,24 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
8888
PrimitiveBuilder::<T>::new(capacity)
8989
}
9090

91+
/// Returns the primitive value at index `i`.
92+
///
93+
/// # Safety
94+
///
95+
/// caller must ensure that the passed in offset is less than the array len()
96+
pub unsafe fn value_unchecked(&self, i: usize) -> T::Native {
97+
let offset = i + self.offset();
98+
*self.raw_values.as_ptr().add(offset)
99+
}
100+
91101
/// Returns the primitive value at index `i`.
92102
///
93103
/// Note this doesn't do any bound checking, for performance reason.
94104
/// # Safety
95105
/// caller must ensure that the passed in offset is less than the array len()
96106
pub fn value(&self, i: usize) -> T::Native {
97-
let offset = i + self.offset();
98-
unsafe { *self.raw_values.as_ptr().add(offset) }
107+
debug_assert!(i < self.len());
108+
unsafe { self.value_unchecked(i) }
99109
}
100110

101111
/// Creates a PrimitiveArray based on an iterator of values without nulls

rust/arrow/src/array/iterator.rs

Lines changed: 60 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,12 @@ impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> {
5656
} else {
5757
let old = self.current;
5858
self.current += 1;
59-
Some(Some(self.array.value(old)))
59+
// Safety:
60+
// we just checked bounds in `self.current_end == self.current`
61+
// this is safe on the premise that this struct is initialized with
62+
// current = array.len()
63+
// and that current_end is ever only decremented
64+
unsafe { Some(Some(self.array.value_unchecked(old))) }
6065
}
6166
}
6267

@@ -77,7 +82,12 @@ impl<'a, T: ArrowPrimitiveType> std::iter::DoubleEndedIterator for PrimitiveIter
7782
Some(if self.array.is_null(self.current_end) {
7883
None
7984
} else {
80-
Some(self.array.value(self.current_end))
85+
// Safety:
86+
// we just checked bounds in `self.current_end == self.current`
87+
// this is safe on the premise that this struct is initialized with
88+
// current = array.len()
89+
// and that current_end is ever only decremented
90+
unsafe { Some(self.array.value_unchecked(self.current_end)) }
8191
})
8292
}
8393
}
@@ -118,7 +128,12 @@ impl<'a> std::iter::Iterator for BooleanIter<'a> {
118128
} else {
119129
let old = self.current;
120130
self.current += 1;
121-
Some(Some(self.array.value(old)))
131+
// Safety:
132+
// we just checked bounds in `self.current_end == self.current`
133+
// this is safe on the premise that this struct is initialized with
134+
// current = array.len()
135+
// and that current_end is ever only decremented
136+
unsafe { Some(Some(self.array.value_unchecked(old))) }
122137
}
123138
}
124139

@@ -139,7 +154,12 @@ impl<'a> std::iter::DoubleEndedIterator for BooleanIter<'a> {
139154
Some(if self.array.is_null(self.current_end) {
140155
None
141156
} else {
142-
Some(self.array.value(self.current_end))
157+
// Safety:
158+
// we just checked bounds in `self.current_end == self.current`
159+
// this is safe on the premise that this struct is initialized with
160+
// current = array.len()
161+
// and that current_end is ever only decremented
162+
unsafe { Some(self.array.value_unchecked(self.current_end)) }
143163
})
144164
}
145165
}
@@ -182,7 +202,12 @@ impl<'a, T: StringOffsetSizeTrait> std::iter::Iterator for GenericStringIter<'a,
182202
Some(None)
183203
} else {
184204
self.current += 1;
185-
Some(Some(self.array.value(i)))
205+
// Safety:
206+
// we just checked bounds in `self.current_end == self.current`
207+
// this is safe on the premise that this struct is initialized with
208+
// current = array.len()
209+
// and that current_end is ever only decremented
210+
unsafe { Some(Some(self.array.value_unchecked(i))) }
186211
}
187212
}
188213

@@ -205,7 +230,12 @@ impl<'a, T: StringOffsetSizeTrait> std::iter::DoubleEndedIterator
205230
Some(if self.array.is_null(self.current_end) {
206231
None
207232
} else {
208-
Some(self.array.value(self.current_end))
233+
// Safety:
234+
// we just checked bounds in `self.current_end == self.current`
235+
// this is safe on the premise that this struct is initialized with
236+
// current = array.len()
237+
// and that current_end is ever only decremented
238+
unsafe { Some(self.array.value_unchecked(self.current_end)) }
209239
})
210240
}
211241
}
@@ -251,7 +281,12 @@ impl<'a, T: BinaryOffsetSizeTrait> std::iter::Iterator for GenericBinaryIter<'a,
251281
Some(None)
252282
} else {
253283
self.current += 1;
254-
Some(Some(self.array.value(i)))
284+
// Safety:
285+
// we just checked bounds in `self.current_end == self.current`
286+
// this is safe on the premise that this struct is initialized with
287+
// current = array.len()
288+
// and that current_end is ever only decremented
289+
unsafe { Some(Some(self.array.value_unchecked(i))) }
255290
}
256291
}
257292

@@ -274,7 +309,12 @@ impl<'a, T: BinaryOffsetSizeTrait> std::iter::DoubleEndedIterator
274309
Some(if self.array.is_null(self.current_end) {
275310
None
276311
} else {
277-
Some(self.array.value(self.current_end))
312+
// Safety:
313+
// we just checked bounds in `self.current_end == self.current`
314+
// this is safe on the premise that this struct is initialized with
315+
// current = array.len()
316+
// and that current_end is ever only decremented
317+
unsafe { Some(self.array.value_unchecked(self.current_end)) }
278318
})
279319
}
280320
}
@@ -318,7 +358,12 @@ impl<'a, S: OffsetSizeTrait> std::iter::Iterator for GenericListArrayIter<'a, S>
318358
Some(None)
319359
} else {
320360
self.current += 1;
321-
Some(Some(self.array.value(i)))
361+
// Safety:
362+
// we just checked bounds in `self.current_end == self.current`
363+
// this is safe on the premise that this struct is initialized with
364+
// current = array.len()
365+
// and that current_end is ever only decremented
366+
unsafe { Some(Some(self.array.value_unchecked(i))) }
322367
}
323368
}
324369

@@ -341,7 +386,12 @@ impl<'a, S: OffsetSizeTrait> std::iter::DoubleEndedIterator
341386
Some(if self.array.is_null(self.current_end) {
342387
None
343388
} else {
344-
Some(self.array.value(self.current_end))
389+
// Safety:
390+
// we just checked bounds in `self.current_end == self.current`
391+
// this is safe on the premise that this struct is initialized with
392+
// current = array.len()
393+
// and that current_end is ever only decremented
394+
unsafe { Some(self.array.value_unchecked(self.current_end)) }
345395
})
346396
}
347397
}

0 commit comments

Comments
 (0)