1
- use crate :: array:: PyArray ;
2
- use crate :: npyffi;
3
- use crate :: npyffi:: array:: PY_ARRAY_API ;
4
- use crate :: npyffi:: objects;
5
- use crate :: npyffi:: types:: { npy_uint32, NPY_CASTING , NPY_ORDER } ;
6
- use pyo3:: prelude:: * ;
1
+ use crate :: array:: { PyArray , PyArrayDyn } ;
2
+ use crate :: npyffi:: {
3
+ array:: PY_ARRAY_API ,
4
+ types:: { NPY_CASTING , NPY_ORDER } ,
5
+ * ,
6
+ } ;
7
+ use crate :: types:: TypeNum ;
8
+ use pyo3:: { prelude:: * , PyNativeType } ;
7
9
8
10
use std:: marker:: PhantomData ;
9
11
use std:: os:: raw:: * ;
10
12
use std:: ptr;
11
13
12
- pub enum NPyIterFlag {
14
+ #[ derive( Clone , Copy , Debug , Eq , PartialEq ) ]
15
+ pub enum NpyIterFlag {
13
16
CIndex ,
14
17
FIndex ,
15
18
MultiIndex ,
@@ -24,105 +27,71 @@ pub enum NPyIterFlag {
24
27
DelayBufAlloc ,
25
28
DontNegateStrides ,
26
29
CopyIfOverlap ,
30
+ ReadWrite ,
31
+ ReadOnly ,
32
+ WriteOnly ,
27
33
}
28
34
29
- /*
30
-
31
- #define NPY_ITER_C_INDEX 0x00000001
32
- #define NPY_ITER_F_INDEX 0x00000002
33
- #define NPY_ITER_MULTI_INDEX 0x00000004
34
- #define NPY_ITER_EXTERNAL_LOOP 0x00000008
35
- #define NPY_ITER_COMMON_DTYPE 0x00000010
36
- #define NPY_ITER_REFS_OK 0x00000020
37
- #define NPY_ITER_ZEROSIZE_OK 0x00000040
38
- #define NPY_ITER_REDUCE_OK 0x00000080
39
- #define NPY_ITER_RANGED 0x00000100
40
- #define NPY_ITER_BUFFERED 0x00000200
41
- #define NPY_ITER_GROWINNER 0x00000400
42
- #define NPY_ITER_DELAY_BUFALLOC 0x00000800
43
- #define NPY_ITER_DONT_NEGATE_STRIDES 0x00001000
44
- #define NPY_ITER_COPY_IF_OVERLAP 0x00002000
45
- #define NPY_ITER_READWRITE 0x00010000
46
- #define NPY_ITER_READONLY 0x00020000
47
- #define NPY_ITER_WRITEONLY 0x00040000
48
- #define NPY_ITER_NBO 0x00080000
49
- #define NPY_ITER_ALIGNED 0x00100000
50
- #define NPY_ITER_CONTIG 0x00200000
51
- #define NPY_ITER_COPY 0x00400000
52
- #define NPY_ITER_UPDATEIFCOPY 0x00800000
53
- #define NPY_ITER_ALLOCATE 0x01000000
54
- #define NPY_ITER_NO_SUBTYPE 0x02000000
55
- #define NPY_ITER_VIRTUAL 0x04000000
56
- #define NPY_ITER_NO_BROADCAST 0x08000000
57
- #define NPY_ITER_WRITEMASKED 0x10000000
58
- #define NPY_ITER_ARRAYMASK 0x20000000
59
- #define NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE 0x40000000
60
-
61
- #define NPY_ITER_GLOBAL_FLAGS 0x0000ffff
62
- #define NPY_ITER_PER_OP_FLAGS 0xffff0000
63
-
64
- */
65
-
66
- impl NPyIterFlag {
35
+ impl NpyIterFlag {
67
36
fn to_c_enum ( & self ) -> npy_uint32 {
68
- use NPyIterFlag :: * ;
37
+ use NpyIterFlag :: * ;
69
38
match self {
70
- CIndex => 0x00000001 ,
71
- FIndex => 0x00000002 ,
72
- MultiIndex => 0x00000004 ,
73
- ExternalLoop => 0x00000008 ,
74
- CommonDtype => 0x00000010 ,
75
- RefsOk => 0x00000020 ,
76
- ZerosizeOk => 0x00000040 ,
77
- ReduceOk => 0x00000080 ,
78
- Ranged => 0x00000100 ,
79
- Buffered => 0x00000200 ,
80
- GrowInner => 0x00000400 ,
81
- DelayBufAlloc => 0x00000800 ,
82
- DontNegateStrides => 0x00001000 ,
83
- CopyIfOverlap => 0x00002000 ,
39
+ CIndex => NPY_ITER_C_INDEX ,
40
+ FIndex => NPY_ITER_C_INDEX ,
41
+ MultiIndex => NPY_ITER_MULTI_INDEX ,
42
+ ExternalLoop => NPY_ITER_EXTERNAL_LOOP ,
43
+ CommonDtype => NPY_ITER_COMMON_DTYPE ,
44
+ RefsOk => NPY_ITER_REFS_OK ,
45
+ ZerosizeOk => NPY_ITER_ZEROSIZE_OK ,
46
+ ReduceOk => NPY_ITER_REDUCE_OK ,
47
+ Ranged => NPY_ITER_RANGED ,
48
+ Buffered => NPY_ITER_BUFFERED ,
49
+ GrowInner => NPY_ITER_GROWINNER ,
50
+ DelayBufAlloc => NPY_ITER_DELAY_BUFALLOC ,
51
+ DontNegateStrides => NPY_ITER_DONT_NEGATE_STRIDES ,
52
+ CopyIfOverlap => NPY_ITER_COPY_IF_OVERLAP ,
53
+ ReadWrite => NPY_ITER_READWRITE ,
54
+ ReadOnly => NPY_ITER_READONLY ,
55
+ WriteOnly => NPY_ITER_WRITEONLY ,
84
56
}
85
57
}
86
58
}
87
59
88
60
pub struct NpyIterBuilder < ' py , T > {
89
61
flags : npy_uint32 ,
90
- array : * mut npyffi:: PyArrayObject ,
91
- py : Python < ' py > ,
92
- return_type : PhantomData < T > ,
62
+ array : & ' py PyArrayDyn < T > ,
93
63
}
94
64
95
- impl < ' py , T > NpyIterBuilder < ' py , T > {
96
- pub fn new < D > ( array : PyArray < T , D > , py : Python < ' py > ) -> NpyIterBuilder < ' py , T > {
65
+ impl < ' py , T : TypeNum > NpyIterBuilder < ' py , T > {
66
+ pub fn new < D : ndarray :: Dimension > ( array : & ' py PyArray < T , D > ) -> NpyIterBuilder < ' py , T > {
97
67
NpyIterBuilder {
98
- array : array. as_array_ptr ( ) ,
99
- py,
100
68
flags : 0 ,
101
- return_type : PhantomData ,
69
+ array : array . into_dyn ( ) ,
102
70
}
103
71
}
104
72
105
- pub fn set_iter_flags ( & mut self , flag : NPyIterFlag , value : bool ) -> & mut Self {
106
- if value {
107
- self . flags |= flag. to_c_enum ( ) ;
108
- } else {
109
- self . flags &= !flag. to_c_enum ( ) ;
110
- }
73
+ pub fn add ( mut self , flag : NpyIterFlag ) -> Self {
74
+ self . flags |= flag. to_c_enum ( ) ;
111
75
self
112
76
}
113
77
114
- pub fn finish ( self ) -> Option < NpyIterSingleArray < ' py , T > > {
78
+ pub fn remove ( mut self , flag : NpyIterFlag ) -> Self {
79
+ self . flags &= !flag. to_c_enum ( ) ;
80
+ self
81
+ }
82
+
83
+ pub fn build ( self ) -> PyResult < NpyIterSingleArray < ' py , T > > {
115
84
let iter_ptr = unsafe {
116
85
PY_ARRAY_API . NpyIter_New (
117
- self . array ,
86
+ self . array . as_array_ptr ( ) ,
118
87
self . flags ,
119
88
NPY_ORDER :: NPY_ANYORDER ,
120
89
NPY_CASTING :: NPY_SAFE_CASTING ,
121
90
ptr:: null_mut ( ) ,
122
91
)
123
92
} ;
124
-
125
- NpyIterSingleArray :: new ( iter_ptr, self . py )
93
+ let py = self . array . py ( ) ;
94
+ NpyIterSingleArray :: new ( iter_ptr, py ) . ok_or_else ( || PyErr :: fetch ( py ) )
126
95
}
127
96
}
128
97
0 commit comments