|
| 1 | +use crate::array::PyArray; |
| 2 | +use crate::npyffi; |
| 3 | +use crate::npyffi::array::PY_ARRAY_API; |
| 4 | +use crate::npyffi::objects; |
| 5 | +use crate::npyffi::types::{npy_uint32, NPY_CASTING, NPY_ORDER}; |
| 6 | +use pyo3::prelude::*; |
| 7 | + |
| 8 | +use std::marker::PhantomData; |
| 9 | +use std::os::raw::*; |
| 10 | +use std::ptr; |
| 11 | + |
| 12 | +pub enum NPyIterFlag { |
| 13 | + CIndex, |
| 14 | + FIndex, |
| 15 | + MultiIndex, |
| 16 | + ExternalLoop, |
| 17 | + CommonDtype, |
| 18 | + RefsOk, |
| 19 | + ZerosizeOk, |
| 20 | + ReduceOk, |
| 21 | + Ranged, |
| 22 | + Buffered, |
| 23 | + GrowInner, |
| 24 | + DelayBufAlloc, |
| 25 | + DontNegateStrides, |
| 26 | + CopyIfOverlap, |
| 27 | +} |
| 28 | + |
| 29 | +/* |
| 30 | +
|
| 31 | +#define NPY_ITER_C_INDEX 0x00000001 |
| 32 | +#define NPY_ITER_F_INDEX 0x00000002 |
| 33 | +#define NPY_ITER_MULTI_INDEX 0x00000004 |
| 34 | +#define NPY_ITER_EXTERNAL_LOOP 0x00000008 |
| 35 | +#define NPY_ITER_COMMON_DTYPE 0x00000010 |
| 36 | +#define NPY_ITER_REFS_OK 0x00000020 |
| 37 | +#define NPY_ITER_ZEROSIZE_OK 0x00000040 |
| 38 | +#define NPY_ITER_REDUCE_OK 0x00000080 |
| 39 | +#define NPY_ITER_RANGED 0x00000100 |
| 40 | +#define NPY_ITER_BUFFERED 0x00000200 |
| 41 | +#define NPY_ITER_GROWINNER 0x00000400 |
| 42 | +#define NPY_ITER_DELAY_BUFALLOC 0x00000800 |
| 43 | +#define NPY_ITER_DONT_NEGATE_STRIDES 0x00001000 |
| 44 | +#define NPY_ITER_COPY_IF_OVERLAP 0x00002000 |
| 45 | +#define NPY_ITER_READWRITE 0x00010000 |
| 46 | +#define NPY_ITER_READONLY 0x00020000 |
| 47 | +#define NPY_ITER_WRITEONLY 0x00040000 |
| 48 | +#define NPY_ITER_NBO 0x00080000 |
| 49 | +#define NPY_ITER_ALIGNED 0x00100000 |
| 50 | +#define NPY_ITER_CONTIG 0x00200000 |
| 51 | +#define NPY_ITER_COPY 0x00400000 |
| 52 | +#define NPY_ITER_UPDATEIFCOPY 0x00800000 |
| 53 | +#define NPY_ITER_ALLOCATE 0x01000000 |
| 54 | +#define NPY_ITER_NO_SUBTYPE 0x02000000 |
| 55 | +#define NPY_ITER_VIRTUAL 0x04000000 |
| 56 | +#define NPY_ITER_NO_BROADCAST 0x08000000 |
| 57 | +#define NPY_ITER_WRITEMASKED 0x10000000 |
| 58 | +#define NPY_ITER_ARRAYMASK 0x20000000 |
| 59 | +#define NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE 0x40000000 |
| 60 | +
|
| 61 | +#define NPY_ITER_GLOBAL_FLAGS 0x0000ffff |
| 62 | +#define NPY_ITER_PER_OP_FLAGS 0xffff0000 |
| 63 | +
|
| 64 | +*/ |
| 65 | + |
| 66 | +impl NPyIterFlag { |
| 67 | + fn to_c_enum(&self) -> npy_uint32 { |
| 68 | + use NPyIterFlag::*; |
| 69 | + match self { |
| 70 | + CIndex => 0x00000001, |
| 71 | + FIndex => 0x00000002, |
| 72 | + MultiIndex => 0x00000004, |
| 73 | + ExternalLoop => 0x00000008, |
| 74 | + CommonDtype => 0x00000010, |
| 75 | + RefsOk => 0x00000020, |
| 76 | + ZerosizeOk => 0x00000040, |
| 77 | + ReduceOk => 0x00000080, |
| 78 | + Ranged => 0x00000100, |
| 79 | + Buffered => 0x00000200, |
| 80 | + GrowInner => 0x00000400, |
| 81 | + DelayBufAlloc => 0x00000800, |
| 82 | + DontNegateStrides => 0x00001000, |
| 83 | + CopyIfOverlap => 0x00002000, |
| 84 | + } |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +pub struct NpyIterBuilder<'py, T> { |
| 89 | + flags: npy_uint32, |
| 90 | + array: *mut npyffi::PyArrayObject, |
| 91 | + py: Python<'py>, |
| 92 | + return_type: PhantomData<T>, |
| 93 | +} |
| 94 | + |
| 95 | +impl<'py, T> NpyIterBuilder<'py, T> { |
| 96 | + pub fn new<D>(array: PyArray<T, D>, py: Python<'py>) -> NpyIterBuilder<'py, T> { |
| 97 | + NpyIterBuilder { |
| 98 | + array: array.as_array_ptr(), |
| 99 | + py, |
| 100 | + flags: 0, |
| 101 | + return_type: PhantomData, |
| 102 | + } |
| 103 | + } |
| 104 | + |
| 105 | + pub fn set_iter_flags(&mut self, flag: NPyIterFlag, value: bool) -> &mut Self { |
| 106 | + if value { |
| 107 | + self.flags |= flag.to_c_enum(); |
| 108 | + } else { |
| 109 | + self.flags &= !flag.to_c_enum(); |
| 110 | + } |
| 111 | + self |
| 112 | + } |
| 113 | + |
| 114 | + pub fn finish(self) -> Option<NpyIterSingleArray<'py, T>> { |
| 115 | + let iter_ptr = unsafe { |
| 116 | + PY_ARRAY_API.NpyIter_New( |
| 117 | + self.array, |
| 118 | + self.flags, |
| 119 | + NPY_ORDER::NPY_ANYORDER, |
| 120 | + NPY_CASTING::NPY_SAFE_CASTING, |
| 121 | + ptr::null_mut(), |
| 122 | + ) |
| 123 | + }; |
| 124 | + |
| 125 | + NpyIterSingleArray::new(iter_ptr, self.py) |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +pub struct NpyIterSingleArray<'py, T> { |
| 130 | + iterator: ptr::NonNull<objects::NpyIter>, |
| 131 | + iternext: unsafe extern "C" fn(*mut objects::NpyIter) -> c_int, |
| 132 | + empty: bool, |
| 133 | + dataptr: *mut *mut c_char, |
| 134 | + |
| 135 | + return_type: PhantomData<T>, |
| 136 | + _py: Python<'py>, |
| 137 | +} |
| 138 | + |
| 139 | +impl<'py, T> NpyIterSingleArray<'py, T> { |
| 140 | + fn new(iterator: *mut objects::NpyIter, py: Python<'py>) -> Option<NpyIterSingleArray<'py, T>> { |
| 141 | + let mut iterator = ptr::NonNull::new(iterator)?; |
| 142 | + |
| 143 | + // TODO replace the null second arg with something correct. |
| 144 | + let iternext = |
| 145 | + unsafe { PY_ARRAY_API.NpyIter_GetIterNext(iterator.as_mut(), ptr::null_mut())? }; |
| 146 | + let dataptr = unsafe { PY_ARRAY_API.NpyIter_GetDataPtrArray(iterator.as_mut()) }; |
| 147 | + |
| 148 | + if dataptr.is_null() { |
| 149 | + unsafe { PY_ARRAY_API.NpyIter_Deallocate(iterator.as_mut()) }; |
| 150 | + } |
| 151 | + |
| 152 | + Some(NpyIterSingleArray { |
| 153 | + iterator, |
| 154 | + iternext, |
| 155 | + empty: false, // TODO: Handle empty iterators |
| 156 | + dataptr, |
| 157 | + return_type: PhantomData, |
| 158 | + _py: py, |
| 159 | + }) |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +impl<'py, T> Drop for NpyIterSingleArray<'py, T> { |
| 164 | + fn drop(&mut self) { |
| 165 | + let _success = unsafe { PY_ARRAY_API.NpyIter_Deallocate(self.iterator.as_mut()) }; |
| 166 | + // TODO: Handle _success somehow? |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +impl<'py, T: 'py> std::iter::Iterator for NpyIterSingleArray<'py, T> { |
| 171 | + type Item = &'py T; |
| 172 | + |
| 173 | + fn next(&mut self) -> Option<Self::Item> { |
| 174 | + if self.empty { |
| 175 | + None |
| 176 | + } else { |
| 177 | + let retval = Some(unsafe { &*(*self.dataptr as *mut T) }); |
| 178 | + self.empty = unsafe { (self.iternext)(self.iterator.as_mut()) } == 0; |
| 179 | + retval |
| 180 | + } |
| 181 | + } |
| 182 | +} |
0 commit comments