Skip to content

Commit ed39b10

Browse files
committed
Add PyFixedString and PyFixedUnicode implementors of Element to support ASCII and Unicode string arrays whose element length is known at compile time.
1 parent b185b6e commit ed39b10

File tree

5 files changed

+389
-5
lines changed

5 files changed

+389
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Changelog
22

33
- Unreleased
4+
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
45

56
- v0.19.0
67
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))

src/datetime.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ impl TypeDescriptors {
223223
fn from_unit<'py>(&'py self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> &'py PyArrayDescr {
224224
let mut dtypes = self.dtypes.get(py).borrow_mut();
225225

226-
match dtypes.get_or_insert_with(Default::default).entry(unit) {
227-
Entry::Occupied(entry) => entry.into_mut().clone().into_ref(py),
226+
let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) {
227+
Entry::Occupied(entry) => entry.into_mut(),
228228
Entry::Vacant(entry) => {
229229
let dtype = PyArrayDescr::new_from_npy_type(py, self.npy_type);
230230

@@ -237,9 +237,11 @@ impl TypeDescriptors {
237237
metadata.meta.num = 1;
238238
}
239239

240-
entry.insert(dtype.into()).clone().into_ref(py)
240+
entry.insert(dtype.into())
241241
}
242-
}
242+
};
243+
244+
dtype.clone().into_ref(py)
243245
}
244246
}
245247

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ mod dtype;
8282
mod error;
8383
pub mod npyffi;
8484
mod slice_container;
85+
mod strings;
8586
mod sum_products;
8687
mod untyped_array;
8788

@@ -105,6 +106,7 @@ pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
105106
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
106107
pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
107108
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
109+
pub use crate::strings::{PyFixedString, PyFixedUnicode};
108110
pub use crate::sum_products::{dot, einsum, inner};
109111
pub use crate::untyped_array::PyUntypedArray;
110112

src/strings.rs

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
//! Types to support arrays of [ASCII][ascii] and [UCS4][ucs4] strings
2+
//!
3+
//! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING
4+
//! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE
5+
6+
use std::cell::RefCell;
7+
use std::collections::hash_map::Entry;
8+
use std::convert::TryInto;
9+
use std::fmt;
10+
use std::mem::size_of;
11+
use std::os::raw::c_char;
12+
use std::str;
13+
14+
use pyo3::{
15+
ffi::{Py_UCS1, Py_UCS4},
16+
sync::GILProtected,
17+
Py, Python,
18+
};
19+
use rustc_hash::FxHashMap;
20+
21+
use crate::dtype::{Element, PyArrayDescr};
22+
use crate::npyffi::NPY_TYPES;
23+
24+
/// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`bytes_` scalars][numpy-bytes] while satisfying coherence.
25+
///
26+
/// Note that when creating arrays of ASCII strings without an explicit `dtype`,
27+
/// NumPy will automatically determine the smallest possible array length at runtime.
28+
///
29+
/// For example,
30+
///
31+
/// ```python
32+
/// array = numpy.array([b"foo", b"bar", b"foobar"])
33+
/// ```
34+
///
35+
/// yields `S6` for `array.dtype`.
36+
///
37+
/// On the Rust side however, the length `N` of `PyFixedString<N>` must always be given
38+
/// explicitly and as a compile-time constant. For this to work reliably, the Python code
39+
/// should set the `dtype` explicitly, e.g.
40+
///
41+
/// ```python
42+
/// numpy.array([b"foo", b"bar", b"foobar"], dtype='S12')
43+
/// ```
44+
///
45+
/// always matching `PyArray1<PyFixedString<12>>`.
46+
///
47+
/// # Example
48+
///
49+
/// ```rust
50+
/// # use pyo3::Python;
51+
/// use numpy::{PyArray1, PyFixedString};
52+
///
53+
/// # Python::with_gil(|py| {
54+
/// let array = PyArray1::<PyFixedString<3>>::from_vec(py, vec![[b'f', b'o', b'o'].into()]);
55+
///
56+
/// assert!(array.dtype().to_string().contains("S3"));
57+
/// # });
58+
/// ```
59+
///
60+
/// [numpy-bytes]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_
61+
#[repr(transparent)]
62+
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
63+
pub struct PyFixedString<const N: usize>(pub [Py_UCS1; N]);
64+
65+
impl<const N: usize> fmt::Display for PyFixedString<N> {
66+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
67+
fmt.write_str(str::from_utf8(&self.0).unwrap().trim_end_matches('\0'))
68+
}
69+
}
70+
71+
impl<const N: usize> From<[Py_UCS1; N]> for PyFixedString<N> {
72+
fn from(val: [Py_UCS1; N]) -> Self {
73+
Self(val)
74+
}
75+
}
76+
77+
unsafe impl<const N: usize> Element for PyFixedString<N> {
78+
const IS_COPY: bool = true;
79+
80+
fn get_dtype(py: Python) -> &PyArrayDescr {
81+
static DTYPES: TypeDescriptors = TypeDescriptors::new();
82+
83+
unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::<Self>()) }
84+
}
85+
}
86+
87+
/// A newtype wrapper around [`[Py_UCS4; N]`][Py_UCS4] to handle [`str_` scalars][numpy-str] while satisfying coherence.
88+
///
89+
/// Note that when creating arrays of Unicode strings without an explicit `dtype`,
90+
/// NumPy will automatically determine the smallest possible array length at runtime.
91+
///
92+
/// For example,
93+
///
94+
/// ```python
95+
/// numpy.array(["foo🐍", "bar🦀", "foobar"])
96+
/// ```
97+
///
98+
/// yields `U6` for `array.dtype`.
99+
///
100+
/// On the Rust side however, the length `N` of `PyFixedUnicode<N>` must always be given
101+
/// explicitly and as a compile-time constant. For this to work reliably, the Python code
102+
/// should set the `dtype` explicitly, e.g.
103+
///
104+
/// ```python
105+
/// numpy.array(["foo🐍", "bar🦀", "foobar"], dtype='U12')
106+
/// ```
107+
///
108+
/// always matching `PyArray1<PyFixedUnicode<12>>`.
109+
///
110+
/// # Example
111+
///
112+
/// ```rust
113+
/// # use pyo3::Python;
114+
/// use numpy::{PyArray1, PyFixedUnicode};
115+
///
116+
/// # Python::with_gil(|py| {
117+
/// let array = PyArray1::<PyFixedUnicode<3>>::from_vec(py, vec![[b'b' as _, b'a' as _, b'r' as _].into()]);
118+
///
119+
/// assert!(array.dtype().to_string().contains("U3"));
120+
/// # });
121+
/// ```
122+
///
123+
/// [numpy-str]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.str_
124+
#[repr(transparent)]
125+
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
126+
pub struct PyFixedUnicode<const N: usize>(pub [Py_UCS4; N]);
127+
128+
impl<const N: usize> fmt::Display for PyFixedUnicode<N> {
129+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
130+
for character in self.0 {
131+
if character == 0 {
132+
break;
133+
}
134+
135+
write!(fmt, "{}", char::from_u32(character).unwrap())?;
136+
}
137+
138+
Ok(())
139+
}
140+
}
141+
142+
impl<const N: usize> From<[Py_UCS4; N]> for PyFixedUnicode<N> {
143+
fn from(val: [Py_UCS4; N]) -> Self {
144+
Self(val)
145+
}
146+
}
147+
148+
unsafe impl<const N: usize> Element for PyFixedUnicode<N> {
149+
const IS_COPY: bool = true;
150+
151+
fn get_dtype(py: Python) -> &PyArrayDescr {
152+
static DTYPES: TypeDescriptors = TypeDescriptors::new();
153+
154+
unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::<Self>()) }
155+
}
156+
}
157+
158+
struct TypeDescriptors {
159+
#[allow(clippy::type_complexity)]
160+
dtypes: GILProtected<RefCell<Option<FxHashMap<usize, Py<PyArrayDescr>>>>>,
161+
}
162+
163+
impl TypeDescriptors {
164+
const fn new() -> Self {
165+
Self {
166+
dtypes: GILProtected::new(RefCell::new(None)),
167+
}
168+
}
169+
170+
/// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size`
171+
#[allow(clippy::wrong_self_convention)]
172+
unsafe fn from_size<'py>(
173+
&'py self,
174+
py: Python<'py>,
175+
npy_type: NPY_TYPES,
176+
byteorder: c_char,
177+
size: usize,
178+
) -> &'py PyArrayDescr {
179+
let mut dtypes = self.dtypes.get(py).borrow_mut();
180+
181+
let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) {
182+
Entry::Occupied(entry) => entry.into_mut(),
183+
Entry::Vacant(entry) => {
184+
let dtype = PyArrayDescr::new_from_npy_type(py, npy_type);
185+
186+
let descr = &mut *dtype.as_dtype_ptr();
187+
descr.elsize = size.try_into().unwrap();
188+
descr.byteorder = byteorder;
189+
190+
entry.insert(dtype.into())
191+
}
192+
};
193+
194+
dtype.clone().into_ref(py)
195+
}
196+
}
197+
198+
#[cfg(test)]
mod tests {
    use super::*;

    /// Trailing NUL padding is dropped; a fully occupied buffer is shown as is.
    #[test]
    fn format_fixed_string() {
        let padded = PyFixedString(*b"foo\0\0\0");
        assert_eq!(padded.to_string(), "foo");

        let full = PyFixedString(*b"foobar");
        assert_eq!(full.to_string(), "foobar");
    }

    /// Same as above for UCS4 code points, including characters outside the BMP.
    #[test]
    fn format_fixed_unicode() {
        let padded = PyFixedUnicode(['f' as _, 'o' as _, 'o' as _, 0, 0, 0]);
        assert_eq!(padded.to_string(), "foo");

        let emoji = PyFixedUnicode(['\u{1F980}' as _, '\u{1F40D}' as _, 0, 0, 0, 0]);
        assert_eq!(emoji.to_string(), "🦀🐍");

        let full = PyFixedUnicode([
            'f' as _, 'o' as _, 'o' as _, 'b' as _, 'a' as _, 'r' as _,
        ]);
        assert_eq!(full.to_string(), "foobar");
    }
}

0 commit comments

Comments
 (0)