Skip to content

Commit 8de9fd9

Browse files
committed
Corrects error handling in NpyIter::new
1 parent 66158f0 commit 8de9fd9

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

src/error.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,15 @@ impl fmt::Display for NotContiguousError {
115115
}
116116

117117
impl_pyerr!(NotContiguousError);
118+
119+
/// Represents issues in NpyIterator instantiation.
120+
#[derive(Debug)]
121+
pub struct NpyIterInstantiationError;
122+
123+
impl fmt::Display for NpyIterInstantiationError {
124+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125+
write!(f, "Unknown error while instantiating NpyIter",)
126+
}
127+
}
128+
129+
impl_pyerr!(NpyIterInstantiationError);

src/npyiter.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::npyffi::{
55
*,
66
};
77
use crate::types::Element;
8+
use crate::error::NpyIterInstantiationError;
89
use pyo3::{prelude::*, PyNativeType};
910

1011
use std::marker::PhantomData;
@@ -100,7 +101,7 @@ impl<'py, T: Element> NpySingleIterBuilder<'py, T> {
100101
)
101102
};
102103
let py = self.array.py();
103-
NpySingleIter::new(iter_ptr, py).ok_or_else(|| PyErr::fetch(py))
104+
NpySingleIter::new(iter_ptr, py)
104105
}
105106
}
106107

@@ -114,19 +115,29 @@ pub struct NpySingleIter<'py, T> {
114115
}
115116

116117
impl<'py, T> NpySingleIter<'py, T> {
117-
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> Option<NpySingleIter<'py, T>> {
118-
let mut iterator = ptr::NonNull::new(iterator)?;
118+
fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> PyResult<NpySingleIter<'py, T>> {
119+
let mut iterator = match ptr::NonNull::new(iterator) {
120+
Some(iter) => iter,
121+
None => {
122+
return Err(NpyIterInstantiationError.into());
123+
}
124+
};
119125

120126
// TODO replace the null second arg with something correct.
121-
let iternext =
122-
unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut())? };
127+
let iternext = match unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut()) } {
128+
Some(ptr) => ptr,
129+
None => {
130+
return Err(PyErr::fetch(py));
131+
}
132+
};
123133
let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(iterator.as_mut()) };
124134

125135
if dataptr.is_null() {
126136
unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) };
137+
return Err(NpyIterInstantiationError.into());
127138
}
128139

129-
Some(NpySingleIter {
140+
Ok(NpySingleIter {
130141
iterator,
131142
iternext,
132143
empty: false, // TODO: Handle empty iterators

0 commit comments

Comments
 (0)