|
| 1 | +//! Types to support arrays of [ASCII][ascii] and [UCS4][ucs4] strings |
| 2 | +//! |
| 3 | +//! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING |
| 4 | +//! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE |
| 5 | +
|
| 6 | +use std::cell::RefCell; |
| 7 | +use std::collections::hash_map::Entry; |
| 8 | +use std::convert::TryInto; |
| 9 | +use std::fmt; |
| 10 | +use std::mem::size_of; |
| 11 | +use std::os::raw::c_char; |
| 12 | +use std::str; |
| 13 | + |
| 14 | +use pyo3::{ |
| 15 | + ffi::{Py_UCS1, Py_UCS4}, |
| 16 | + sync::GILProtected, |
| 17 | + Py, Python, |
| 18 | +}; |
| 19 | +use rustc_hash::FxHashMap; |
| 20 | + |
| 21 | +use crate::dtype::{Element, PyArrayDescr}; |
| 22 | +use crate::npyffi::NPY_TYPES; |
| 23 | + |
| 24 | +/// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`byte` scalars][numpy-bytes] while satisfying coherence. |
| 25 | +/// |
| 26 | +/// Note that when creating arrays of ASCII strings without an explicit `dtype`, |
| 27 | +/// NumPy will automatically determine the smallest possible array length at runtime. |
| 28 | +/// |
| 29 | +/// For example, |
| 30 | +/// |
| 31 | +/// ```python |
| 32 | +/// array = numpy.array([b"foo", b"bar", b"foobar"]) |
| 33 | +/// ``` |
| 34 | +/// |
| 35 | +/// yields `S6` for `array.dtype`. |
| 36 | +/// |
| 37 | +/// On the Rust side however, the length `N` of `PyFixedString<N>` must always be given |
| 38 | +/// explicitly and as a compile-time constant. For this work reliably, the Python code |
| 39 | +/// should set the `dtype` explicitly, e.g. |
| 40 | +/// |
| 41 | +/// ```python |
| 42 | +/// numpy.array([b"foo", b"bar", b"foobar"], dtype='S12') |
| 43 | +/// ``` |
| 44 | +/// |
| 45 | +/// always matching `PyArray1<PyFixedString<12>>`. |
| 46 | +/// |
| 47 | +/// # Example |
| 48 | +/// |
| 49 | +/// ```rust |
| 50 | +/// # use pyo3::Python; |
| 51 | +/// use numpy::{PyArray1, PyFixedString}; |
| 52 | +/// |
| 53 | +/// # Python::with_gil(|py| { |
| 54 | +/// let array = PyArray1::<PyFixedString<3>>::from_vec(py, vec![[b'f', b'o', b'o'].into()]); |
| 55 | +/// |
| 56 | +/// assert!(array.dtype().to_string().contains("S3")); |
| 57 | +/// # }); |
| 58 | +/// ``` |
| 59 | +/// |
| 60 | +/// [numpy-bytes]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_ |
| 61 | +#[repr(transparent)] |
| 62 | +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] |
| 63 | +pub struct PyFixedString<const N: usize>(pub [Py_UCS1; N]); |
| 64 | + |
| 65 | +impl<const N: usize> fmt::Display for PyFixedString<N> { |
| 66 | + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
| 67 | + fmt.write_str(str::from_utf8(&self.0).unwrap().trim_end_matches('\0')) |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +impl<const N: usize> From<[Py_UCS1; N]> for PyFixedString<N> { |
| 72 | + fn from(val: [Py_UCS1; N]) -> Self { |
| 73 | + Self(val) |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +unsafe impl<const N: usize> Element for PyFixedString<N> { |
| 78 | + const IS_COPY: bool = true; |
| 79 | + |
| 80 | + fn get_dtype(py: Python) -> &PyArrayDescr { |
| 81 | + static DTYPES: TypeDescriptors = TypeDescriptors::new(); |
| 82 | + |
| 83 | + unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::<Self>()) } |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +/// A newtype wrapper around [`[PyUCS4; N]`][Py_UCS4] to handle [`str_` scalars][numpy-str] while satisfying coherence. |
| 88 | +/// |
| 89 | +/// Note that when creating arrays of Unicode strings without an explicit `dtype`, |
| 90 | +/// NumPy will automatically determine the smallest possible array length at runtime. |
| 91 | +/// |
| 92 | +/// For example, |
| 93 | +/// |
| 94 | +/// ```python |
| 95 | +/// numpy.array(["foo🐍", "bar🦀", "foobar"]) |
| 96 | +/// ``` |
| 97 | +/// |
| 98 | +/// yields `U6` for `array.dtype`. |
| 99 | +/// |
| 100 | +/// On the Rust side however, the length `N` of `PyFixedUnicode<N>` must always be given |
| 101 | +/// explicitly and as a compile-time constant. For this work reliably, the Python code |
| 102 | +/// should set the `dtype` explicitly, e.g. |
| 103 | +/// |
| 104 | +/// ```python |
| 105 | +/// numpy.array(["foo🐍", "bar🦀", "foobar"], dtype='U12') |
| 106 | +/// ``` |
| 107 | +/// |
| 108 | +/// always matching `PyArray1<PyFixedUnicode<12>>`. |
| 109 | +/// |
| 110 | +/// # Example |
| 111 | +/// |
| 112 | +/// ```rust |
| 113 | +/// # use pyo3::Python; |
| 114 | +/// use numpy::{PyArray1, PyFixedUnicode}; |
| 115 | +/// |
| 116 | +/// # Python::with_gil(|py| { |
| 117 | +/// let array = PyArray1::<PyFixedUnicode<3>>::from_vec(py, vec![[b'b' as _, b'a' as _, b'r' as _].into()]); |
| 118 | +/// |
| 119 | +/// assert!(array.dtype().to_string().contains("U3")); |
| 120 | +/// # }); |
| 121 | +/// ``` |
| 122 | +/// |
| 123 | +/// [numpy-str]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.str_ |
| 124 | +#[repr(transparent)] |
| 125 | +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] |
| 126 | +pub struct PyFixedUnicode<const N: usize>(pub [Py_UCS4; N]); |
| 127 | + |
| 128 | +impl<const N: usize> fmt::Display for PyFixedUnicode<N> { |
| 129 | + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { |
| 130 | + for character in self.0 { |
| 131 | + if character == 0 { |
| 132 | + break; |
| 133 | + } |
| 134 | + |
| 135 | + write!(fmt, "{}", char::from_u32(character).unwrap())?; |
| 136 | + } |
| 137 | + |
| 138 | + Ok(()) |
| 139 | + } |
| 140 | +} |
| 141 | + |
| 142 | +impl<const N: usize> From<[Py_UCS4; N]> for PyFixedUnicode<N> { |
| 143 | + fn from(val: [Py_UCS4; N]) -> Self { |
| 144 | + Self(val) |
| 145 | + } |
| 146 | +} |
| 147 | + |
| 148 | +unsafe impl<const N: usize> Element for PyFixedUnicode<N> { |
| 149 | + const IS_COPY: bool = true; |
| 150 | + |
| 151 | + fn get_dtype(py: Python) -> &PyArrayDescr { |
| 152 | + static DTYPES: TypeDescriptors = TypeDescriptors::new(); |
| 153 | + |
| 154 | + unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::<Self>()) } |
| 155 | + } |
| 156 | +} |
| 157 | + |
| 158 | +struct TypeDescriptors { |
| 159 | + #[allow(clippy::type_complexity)] |
| 160 | + dtypes: GILProtected<RefCell<Option<FxHashMap<usize, Py<PyArrayDescr>>>>>, |
| 161 | +} |
| 162 | + |
| 163 | +impl TypeDescriptors { |
| 164 | + const fn new() -> Self { |
| 165 | + Self { |
| 166 | + dtypes: GILProtected::new(RefCell::new(None)), |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | + /// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size` |
| 171 | + #[allow(clippy::wrong_self_convention)] |
| 172 | + unsafe fn from_size<'py>( |
| 173 | + &'py self, |
| 174 | + py: Python<'py>, |
| 175 | + npy_type: NPY_TYPES, |
| 176 | + byteorder: c_char, |
| 177 | + size: usize, |
| 178 | + ) -> &'py PyArrayDescr { |
| 179 | + let mut dtypes = self.dtypes.get(py).borrow_mut(); |
| 180 | + |
| 181 | + let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) { |
| 182 | + Entry::Occupied(entry) => entry.into_mut(), |
| 183 | + Entry::Vacant(entry) => { |
| 184 | + let dtype = PyArrayDescr::new_from_npy_type(py, npy_type); |
| 185 | + |
| 186 | + let descr = &mut *dtype.as_dtype_ptr(); |
| 187 | + descr.elsize = size.try_into().unwrap(); |
| 188 | + descr.byteorder = byteorder; |
| 189 | + |
| 190 | + entry.insert(dtype.into()) |
| 191 | + } |
| 192 | + }; |
| 193 | + |
| 194 | + dtype.clone().into_ref(py) |
| 195 | + } |
| 196 | +} |
| 197 | + |
| 198 | +#[cfg(test)] |
| 199 | +mod tests { |
| 200 | + use super::*; |
| 201 | + |
| 202 | + #[test] |
| 203 | + fn format_fixed_string() { |
| 204 | + assert_eq!( |
| 205 | + PyFixedString([b'f', b'o', b'o', 0, 0, 0]).to_string(), |
| 206 | + "foo" |
| 207 | + ); |
| 208 | + assert_eq!( |
| 209 | + PyFixedString([b'f', b'o', b'o', b'b', b'a', b'r']).to_string(), |
| 210 | + "foobar" |
| 211 | + ); |
| 212 | + } |
| 213 | + |
| 214 | + #[test] |
| 215 | + fn format_fixed_unicode() { |
| 216 | + assert_eq!( |
| 217 | + PyFixedUnicode([b'f' as _, b'o' as _, b'o' as _, 0, 0, 0]).to_string(), |
| 218 | + "foo" |
| 219 | + ); |
| 220 | + assert_eq!( |
| 221 | + PyFixedUnicode([0x1F980, 0x1F40D, 0, 0, 0, 0]).to_string(), |
| 222 | + "🦀🐍" |
| 223 | + ); |
| 224 | + assert_eq!( |
| 225 | + PyFixedUnicode([b'f' as _, b'o' as _, b'o' as _, b'b' as _, b'a' as _, b'r' as _]) |
| 226 | + .to_string(), |
| 227 | + "foobar" |
| 228 | + ); |
| 229 | + } |
| 230 | +} |
0 commit comments