Skip to content

Commit 6864a8a

Browse files
committed
Do not pessimize from_vec2/3 towards rejecting ragged arrays.
1 parent a8815fd commit 6864a8a

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

src/array.rs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,16 +1097,15 @@ impl<T: Element> PyArray<T, Ix2> {
10971097
/// ```
10981098
pub fn from_vec2<'py>(py: Python<'py>, v: &[Vec<T>]) -> Result<&'py Self, FromVecError> {
10991099
let len2 = v.first().map_or(0, |v| v.len());
1100-
for v in v {
1101-
if v.len() != len2 {
1102-
return Err(FromVecError::new(v.len(), len2));
1103-
}
1104-
}
11051100
let dims = [v.len(), len2];
1101+
// SAFETY: The result of `Self::new` is always safe to drop.
11061102
unsafe {
11071103
let array = Self::new(py, dims, false);
11081104
let mut data_ptr = array.data();
11091105
for v in v {
1106+
if v.len() != len2 {
1107+
return Err(FromVecError::new(v.len(), len2));
1108+
}
11101109
if T::IS_COPY {
11111110
ptr::copy_nonoverlapping(v.as_ptr(), data_ptr, len2);
11121111
data_ptr = data_ptr.add(len2);
@@ -1144,25 +1143,20 @@ impl<T: Element> PyArray<T, Ix3> {
11441143
/// ```
11451144
pub fn from_vec3<'py>(py: Python<'py>, v: &[Vec<Vec<T>>]) -> Result<&'py Self, FromVecError> {
11461145
let len2 = v.first().map_or(0, |v| v.len());
1147-
for v in v {
1148-
if v.len() != len2 {
1149-
return Err(FromVecError::new(v.len(), len2));
1150-
}
1151-
}
11521146
let len3 = v.first().map_or(0, |v| v.first().map_or(0, |v| v.len()));
1153-
for v in v {
1154-
for v in v {
1155-
if v.len() != len3 {
1156-
return Err(FromVecError::new(v.len(), len3));
1157-
}
1158-
}
1159-
}
11601147
let dims = [v.len(), len2, len3];
1148+
// SAFETY: The result of `Self::new` is always safe to drop.
11611149
unsafe {
11621150
let array = Self::new(py, dims, false);
11631151
let mut data_ptr = array.data();
11641152
for v in v {
1153+
if v.len() != len2 {
1154+
return Err(FromVecError::new(v.len(), len2));
1155+
}
11651156
for v in v {
1157+
if v.len() != len3 {
1158+
return Err(FromVecError::new(v.len(), len3));
1159+
}
11661160
if T::IS_COPY {
11671161
ptr::copy_nonoverlapping(v.as_ptr(), data_ptr, len3);
11681162
data_ptr = data_ptr.add(len3);

0 commit comments

Comments
 (0)