Skip to content

Commit 190261d

Browse files
authored
Use slice::align_to (#135)
When checking that tensor data can be converted to a slice of `T`, we previously were checking the right things: is the length right? Is the alignment correct? Turns out there is a `std` function that does this for us: `slice::align_to`. Replacing the custom check with the `std` version should have no effect on the code other than clarity.
1 parent 9242a8c commit 190261d

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

crates/openvino/src/tensor.rs

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ impl Tensor {
116116
/// underlying pointer's alignment.
117117
pub fn get_data<T>(&self) -> Result<&[T]> {
118118
let raw_data = self.get_raw_data()?;
119-
let len = get_safe_len::<T>(raw_data);
120-
let slice = unsafe { std::slice::from_raw_parts(raw_data.as_ptr().cast::<T>(), len) };
119+
let (prefix, slice, suffix) = unsafe { raw_data.align_to::<T>() };
120+
assert!(
121+
prefix.is_empty() && suffix.is_empty(),
122+
"raw data is not aligned to `T`'s alignment"
123+
);
121124
Ok(slice)
122125
}
123126

@@ -129,27 +132,15 @@ impl Tensor {
129132
/// underlying pointer's alignment.
130133
pub fn get_data_mut<T>(&mut self) -> Result<&mut [T]> {
131134
let raw_data = self.get_raw_data_mut()?;
132-
let len = get_safe_len::<T>(raw_data);
133-
let slice =
134-
unsafe { std::slice::from_raw_parts_mut(raw_data.as_mut_ptr().cast::<T>(), len) };
135+
let (prefix, slice, suffix) = unsafe { raw_data.align_to_mut::<T>() };
136+
assert!(
137+
prefix.is_empty() && suffix.is_empty(),
138+
"raw data is not aligned to `T`'s alignment"
139+
);
135140
Ok(slice)
136141
}
137142
}
138143

139-
/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
140-
/// length of that slice.
141-
fn get_safe_len<T>(data: &[u8]) -> usize {
142-
assert!(
143-
data.len() % std::mem::size_of::<T>() == 0,
144-
"data size is not a multiple of the size of `T`"
145-
);
146-
assert!(
147-
data.as_ptr() as usize % std::mem::align_of::<T>() == 0,
148-
"raw data is not aligned to `T`'s alignment"
149-
);
150-
data.len() / std::mem::size_of::<T>()
151-
}
152-
153144
#[cfg(test)]
154145
mod tests {
155146
use super::*;
@@ -208,7 +199,7 @@ mod tests {
208199
}
209200

210201
#[test]
211-
#[should_panic(expected = "data size is not a multiple of the size of `T`")]
202+
#[should_panic(expected = "raw data is not aligned to `T`'s alignment")]
212203
fn casting_check() {
213204
openvino_sys::library::load().unwrap();
214205
let shape = Shape::new(&[10, 10, 10]).unwrap();

0 commit comments

Comments
 (0)