Skip to content

Commit c79f220

Browse files
kngwyuPTNobel
authored andcommitted
Mentoring: Refactor with PyO3 idioms and Add a test
1 parent f38bbd8 commit c79f220

File tree

3 files changed

+95
-78
lines changed

3 files changed

+95
-78
lines changed

src/npyffi/flags.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::npy_uint32;
12
use std::os::raw::c_int;
23

34
pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
@@ -28,3 +29,36 @@ pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
2829
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
2930
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
3031
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;
32+
33+
pub const NPY_ITER_C_INDEX: npy_uint32 = 0x00000001;
34+
pub const NPY_ITER_F_INDEX: npy_uint32 = 0x00000002;
35+
pub const NPY_ITER_MULTI_INDEX: npy_uint32 = 0x00000004;
36+
pub const NPY_ITER_EXTERNAL_LOOP: npy_uint32 = 0x00000008;
37+
pub const NPY_ITER_COMMON_DTYPE: npy_uint32 = 0x00000010;
38+
pub const NPY_ITER_REFS_OK: npy_uint32 = 0x00000020;
39+
pub const NPY_ITER_ZEROSIZE_OK: npy_uint32 = 0x00000040;
40+
pub const NPY_ITER_REDUCE_OK: npy_uint32 = 0x00000080;
41+
pub const NPY_ITER_RANGED: npy_uint32 = 0x00000100;
42+
pub const NPY_ITER_BUFFERED: npy_uint32 = 0x00000200;
43+
pub const NPY_ITER_GROWINNER: npy_uint32 = 0x00000400;
44+
pub const NPY_ITER_DELAY_BUFALLOC: npy_uint32 = 0x00000800;
45+
pub const NPY_ITER_DONT_NEGATE_STRIDES: npy_uint32 = 0x00001000;
46+
pub const NPY_ITER_COPY_IF_OVERLAP: npy_uint32 = 0x00002000;
47+
pub const NPY_ITER_READWRITE: npy_uint32 = 0x00010000;
48+
pub const NPY_ITER_READONLY: npy_uint32 = 0x00020000;
49+
pub const NPY_ITER_WRITEONLY: npy_uint32 = 0x00040000;
50+
pub const NPY_ITER_NBO: npy_uint32 = 0x00080000;
51+
pub const NPY_ITER_ALIGNED: npy_uint32 = 0x00100000;
52+
pub const NPY_ITER_CONTIG: npy_uint32 = 0x00200000;
53+
pub const NPY_ITER_COPY: npy_uint32 = 0x00400000;
54+
pub const NPY_ITER_UPDATEIFCOPY: npy_uint32 = 0x00800000;
55+
pub const NPY_ITER_ALLOCATE: npy_uint32 = 0x01000000;
56+
pub const NPY_ITER_NO_SUBTYPE: npy_uint32 = 0x02000000;
57+
pub const NPY_ITER_VIRTUAL: npy_uint32 = 0x04000000;
58+
pub const NPY_ITER_NO_BROADCAST: npy_uint32 = 0x08000000;
59+
pub const NPY_ITER_WRITEMASKED: npy_uint32 = 0x10000000;
60+
pub const NPY_ITER_ARRAYMASK: npy_uint32 = 0x20000000;
61+
pub const NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE: npy_uint32 = 0x40000000;
62+
63+
pub const NPY_ITER_GLOBAL_FLAGS: npy_uint32 = 0x0000ffff;
64+
pub const NPY_ITER_PER_OP_FLAGS: npy_uint32 = 0xffff0000;

src/npyiter.rs

Lines changed: 47 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
use crate::array::PyArray;
2-
use crate::npyffi;
3-
use crate::npyffi::array::PY_ARRAY_API;
4-
use crate::npyffi::objects;
5-
use crate::npyffi::types::{npy_uint32, NPY_CASTING, NPY_ORDER};
6-
use pyo3::prelude::*;
1+
use crate::array::{PyArray, PyArrayDyn};
2+
use crate::npyffi::{
3+
array::PY_ARRAY_API,
4+
types::{NPY_CASTING, NPY_ORDER},
5+
*,
6+
};
7+
use crate::types::TypeNum;
8+
use pyo3::{prelude::*, PyNativeType};
79

810
use std::marker::PhantomData;
911
use std::os::raw::*;
1012
use std::ptr;
1113

12-
pub enum NPyIterFlag {
14+
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
15+
pub enum NpyIterFlag {
1316
CIndex,
1417
FIndex,
1518
MultiIndex,
@@ -24,105 +27,71 @@ pub enum NPyIterFlag {
2427
DelayBufAlloc,
2528
DontNegateStrides,
2629
CopyIfOverlap,
30+
ReadWrite,
31+
ReadOnly,
32+
WriteOnly,
2733
}
2834

29-
/*
30-
31-
#define NPY_ITER_C_INDEX 0x00000001
32-
#define NPY_ITER_F_INDEX 0x00000002
33-
#define NPY_ITER_MULTI_INDEX 0x00000004
34-
#define NPY_ITER_EXTERNAL_LOOP 0x00000008
35-
#define NPY_ITER_COMMON_DTYPE 0x00000010
36-
#define NPY_ITER_REFS_OK 0x00000020
37-
#define NPY_ITER_ZEROSIZE_OK 0x00000040
38-
#define NPY_ITER_REDUCE_OK 0x00000080
39-
#define NPY_ITER_RANGED 0x00000100
40-
#define NPY_ITER_BUFFERED 0x00000200
41-
#define NPY_ITER_GROWINNER 0x00000400
42-
#define NPY_ITER_DELAY_BUFALLOC 0x00000800
43-
#define NPY_ITER_DONT_NEGATE_STRIDES 0x00001000
44-
#define NPY_ITER_COPY_IF_OVERLAP 0x00002000
45-
#define NPY_ITER_READWRITE 0x00010000
46-
#define NPY_ITER_READONLY 0x00020000
47-
#define NPY_ITER_WRITEONLY 0x00040000
48-
#define NPY_ITER_NBO 0x00080000
49-
#define NPY_ITER_ALIGNED 0x00100000
50-
#define NPY_ITER_CONTIG 0x00200000
51-
#define NPY_ITER_COPY 0x00400000
52-
#define NPY_ITER_UPDATEIFCOPY 0x00800000
53-
#define NPY_ITER_ALLOCATE 0x01000000
54-
#define NPY_ITER_NO_SUBTYPE 0x02000000
55-
#define NPY_ITER_VIRTUAL 0x04000000
56-
#define NPY_ITER_NO_BROADCAST 0x08000000
57-
#define NPY_ITER_WRITEMASKED 0x10000000
58-
#define NPY_ITER_ARRAYMASK 0x20000000
59-
#define NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE 0x40000000
60-
61-
#define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
62-
#define NPY_ITER_PER_OP_FLAGS 0xffff0000
63-
64-
*/
65-
66-
impl NPyIterFlag {
35+
impl NpyIterFlag {
6736
fn to_c_enum(&self) -> npy_uint32 {
68-
use NPyIterFlag::*;
37+
use NpyIterFlag::*;
6938
match self {
70-
CIndex => 0x00000001,
71-
FIndex => 0x00000002,
72-
MultiIndex => 0x00000004,
73-
ExternalLoop => 0x00000008,
74-
CommonDtype => 0x00000010,
75-
RefsOk => 0x00000020,
76-
ZerosizeOk => 0x00000040,
77-
ReduceOk => 0x00000080,
78-
Ranged => 0x00000100,
79-
Buffered => 0x00000200,
80-
GrowInner => 0x00000400,
81-
DelayBufAlloc => 0x00000800,
82-
DontNegateStrides => 0x00001000,
83-
CopyIfOverlap => 0x00002000,
39+
CIndex => NPY_ITER_C_INDEX,
40+
FIndex => NPY_ITER_C_INDEX,
41+
MultiIndex => NPY_ITER_MULTI_INDEX,
42+
ExternalLoop => NPY_ITER_EXTERNAL_LOOP,
43+
CommonDtype => NPY_ITER_COMMON_DTYPE,
44+
RefsOk => NPY_ITER_REFS_OK,
45+
ZerosizeOk => NPY_ITER_ZEROSIZE_OK,
46+
ReduceOk => NPY_ITER_REDUCE_OK,
47+
Ranged => NPY_ITER_RANGED,
48+
Buffered => NPY_ITER_BUFFERED,
49+
GrowInner => NPY_ITER_GROWINNER,
50+
DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC,
51+
DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES,
52+
CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP,
53+
ReadWrite => NPY_ITER_READWRITE,
54+
ReadOnly => NPY_ITER_READONLY,
55+
WriteOnly => NPY_ITER_WRITEONLY,
8456
}
8557
}
8658
}
8759

8860
pub struct NpyIterBuilder<'py, T> {
8961
flags: npy_uint32,
90-
array: *mut npyffi::PyArrayObject,
91-
py: Python<'py>,
92-
return_type: PhantomData<T>,
62+
array: &'py PyArrayDyn<T>,
9363
}
9464

95-
impl<'py, T> NpyIterBuilder<'py, T> {
96-
pub fn new<D>(array: PyArray<T, D>, py: Python<'py>) -> NpyIterBuilder<'py, T> {
65+
impl<'py, T: TypeNum> NpyIterBuilder<'py, T> {
66+
pub fn new<D: ndarray::Dimension>(array: &'py PyArray<T, D>) -> NpyIterBuilder<'py, T> {
9767
NpyIterBuilder {
98-
array: array.as_array_ptr(),
99-
py,
10068
flags: 0,
101-
return_type: PhantomData,
69+
array: array.into_dyn(),
10270
}
10371
}
10472

105-
pub fn set_iter_flags(&mut self, flag: NPyIterFlag, value: bool) -> &mut Self {
106-
if value {
107-
self.flags |= flag.to_c_enum();
108-
} else {
109-
self.flags &= !flag.to_c_enum();
110-
}
73+
pub fn add(mut self, flag: NpyIterFlag) -> Self {
74+
self.flags |= flag.to_c_enum();
11175
self
11276
}
11377

114-
pub fn finish(self) -> Option<NpyIterSingleArray<'py, T>> {
78+
pub fn remove(mut self, flag: NpyIterFlag) -> Self {
79+
self.flags &= !flag.to_c_enum();
80+
self
81+
}
82+
83+
pub fn build(self) -> PyResult<NpyIterSingleArray<'py, T>> {
11584
let iter_ptr = unsafe {
11685
PY_ARRAY_API.NpyIter_New(
117-
self.array,
86+
self.array.as_array_ptr(),
11887
self.flags,
11988
NPY_ORDER::NPY_ANYORDER,
12089
NPY_CASTING::NPY_SAFE_CASTING,
12190
ptr::null_mut(),
12291
)
12392
};
124-
125-
NpyIterSingleArray::new(iter_ptr, self.py)
93+
let py = self.array.py();
94+
NpyIterSingleArray::new(iter_ptr, py).ok_or_else(|| PyErr::fetch(py))
12695
}
12796
}
12897

tests/iter.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use numpy::{npyiter::NpyIterFlag, *};
2+
3+
#[test]
4+
fn get_iter() {
5+
let gil = pyo3::Python::acquire_gil();
6+
let dim = (3, 5);
7+
let arr = PyArray::<f64, _>::zeros(gil.python(), dim, false);
8+
let mut iter = npyiter::NpyIterBuilder::new(arr)
9+
.add(NpyIterFlag::ReadOnly)
10+
.build()
11+
.map_err(|e| e.print(gil.python()))
12+
.unwrap();
13+
assert_eq!(*iter.next().unwrap(), 0.0);
14+
}

0 commit comments

Comments
 (0)