Skip to content

Commit 82e641b

Browse files
committed
Use PyArray_NewFromDescr, remove npy_type()
1 parent e09d48f commit 82e641b

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

src/array.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -428,14 +428,13 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
428428
ID: IntoDimension<Dim = D>,
429429
{
430430
let dims = dims.into_dimension();
431-
let ptr = PY_ARRAY_API.PyArray_New(
431+
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
432432
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
433+
T::get_dtype(py).into_ptr() as _,
433434
dims.ndim_cint(),
434435
dims.as_dims_ptr(),
435-
T::npy_type() as c_int,
436436
strides as *mut npy_intp, // strides
437437
ptr::null_mut(), // data
438-
0, // itemsize
439438
flag, // flag
440439
ptr::null_mut(), // obj
441440
);
@@ -453,14 +452,13 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
453452
ID: IntoDimension<Dim = D>,
454453
{
455454
let dims = dims.into_dimension();
456-
let ptr = PY_ARRAY_API.PyArray_New(
455+
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
457456
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
457+
T::get_dtype(py).into_ptr() as _,
458458
dims.ndim_cint(),
459459
dims.as_dims_ptr(),
460-
T::npy_type() as c_int,
461460
strides as *mut npy_intp, // strides
462461
data_ptr as *mut c_void, // data
463-
mem::size_of::<T>() as c_int, // itemsize
464462
npyffi::NPY_ARRAY_WRITEABLE, // flag
465463
ptr::null_mut(), // obj
466464
);
@@ -1193,7 +1191,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
11931191
start.as_(),
11941192
stop.as_(),
11951193
step.as_(),
1196-
T::npy_type() as i32,
1194+
T::get_dtype(py).get_typenum(),
11971195
);
11981196
Self::from_owned_ptr(py, ptr)
11991197
}

src/dtype.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,19 @@ impl PyArrayDescr {
6363
DataType::from_typenum(self.get_typenum())
6464
}
6565

66+
/// Shortcut for creating a descriptor of 'object' type.
67+
pub fn object(py: Python) -> &Self {
68+
Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
69+
}
70+
6671
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
6772
unsafe {
6873
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32);
6974
py.from_owned_ptr(descr as _)
7075
}
7176
}
7277

73-
fn get_typenum(&self) -> std::os::raw::c_int {
78+
pub(crate) fn get_typenum(&self) -> std::os::raw::c_int {
7479
unsafe { *self.as_dtype_ptr() }.type_num
7580
}
7681
}
@@ -206,6 +211,10 @@ impl DataType {
206211
/// fn is_same_type(dtype: &PyArrayDescr) -> bool {
207212
/// dtype.get_datatype() == Some(DataType::Object)
208213
/// }
214+
///
215+
/// fn get_dtype(py: Python) -> &PyArrayDescr {
216+
/// PyArrayDescr::object(py)
217+
/// }
209218
/// }
210219
///
211220
/// Python::with_gil(|py| {
@@ -223,17 +232,8 @@ pub unsafe trait Element: Clone + Send {
223232
/// Returns if the give `dtype` is convertible to `Self` in Rust.
224233
fn is_same_type(dtype: &PyArrayDescr) -> bool;
225234

226-
/// Returns the corresponding
227-
/// [Enumerated Type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
228-
#[inline]
229-
fn npy_type() -> NPY_TYPES {
230-
Self::DATA_TYPE.into_ctype()
231-
}
232-
233235
/// Create `dtype`.
234-
fn get_dtype(py: Python) -> &PyArrayDescr {
235-
PyArrayDescr::from_npy_type(py, Self::npy_type())
236-
}
236+
fn get_dtype(py: Python) -> &PyArrayDescr;
237237
}
238238

239239
macro_rules! impl_num_element {
@@ -243,6 +243,9 @@ macro_rules! impl_num_element {
243243
fn is_same_type(dtype: &PyArrayDescr) -> bool {
244244
$(dtype.get_typenum() == NPY_TYPES::$npy_types as i32 ||)+ false
245245
}
246+
fn get_dtype(py: Python) -> &PyArrayDescr {
247+
PyArrayDescr::from_npy_type(py, DataType::$npy_dat_t.into_ctype())
248+
}
246249
}
247250
};
248251
}
@@ -287,4 +290,7 @@ unsafe impl Element for PyObject {
287290
fn is_same_type(dtype: &PyArrayDescr) -> bool {
288291
dtype.get_typenum() == NPY_TYPES::NPY_OBJECT as i32
289292
}
293+
fn get_dtype(py: Python) -> &PyArrayDescr {
294+
PyArrayDescr::object(py)
295+
}
290296
}

0 commit comments

Comments
 (0)