Skip to content

Commit ab71b1e

Browse files
JRRudy1davidhewitt
authored andcommitted
Added impl_py_clone macro that provides an efficient PyClone impl for Clone types.
The `impl_element_scalar` macro has been updated to invoke this new macro, and an explicit invocation was added for the types in the `datetime` module.
1 parent 437f5fa commit ab71b1e

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/datetime.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use std::marker::PhantomData;
6666
use pyo3::{sync::GILProtected, Bound, Py, Python};
6767
use rustc_hash::FxHashMap;
6868

69-
use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods};
69+
use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods, impl_py_clone};
7070
use crate::npyffi::{
7171
PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
7272
};
@@ -155,6 +155,8 @@ impl<U: Unit> From<Datetime<U>> for i64 {
155155
}
156156
}
157157

158+
impl_py_clone!(Datetime<U>; [U: Unit]);
159+
158160
unsafe impl<U: Unit> Element for Datetime<U> {
159161
const IS_COPY: bool = true;
160162

@@ -190,6 +192,8 @@ impl<U: Unit> From<Timedelta<U>> for i64 {
190192
}
191193
}
192194

195+
impl_py_clone!(Timedelta<U>; [U: Unit]);
196+
193197
unsafe impl<U: Unit> Element for Timedelta<U> {
194198
const IS_COPY: bool = true;
195199

src/dtype.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,35 @@ fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
796796
}
797797
}
798798

799+
// Implements `PyClone` for a type that implements `Clone`
800+
macro_rules! impl_py_clone {
801+
($ty:ty $(; [$param:ident $(: $bound:ident)?])?) => {
802+
impl <$($param$(: $bound)*)?> $crate::dtype::PyClone for $ty {
803+
#[inline]
804+
fn py_clone(&self, _py: ::pyo3::Python<'_>) -> Self {
805+
self.clone()
806+
}
807+
808+
#[inline]
809+
fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[Self]) -> Vec<Self> {
810+
slc.to_owned()
811+
}
812+
813+
#[inline]
814+
fn array_from_view<D>(
815+
_py: ::pyo3::Python<'_>,
816+
view: ::ndarray::ArrayView<'_, Self, D>
817+
) -> ::ndarray::Array<Self, D>
818+
where
819+
D: ::ndarray::Dimension
820+
{
821+
view.to_owned()
822+
}
823+
}
824+
}
825+
}
826+
pub(crate) use impl_py_clone;
827+
799828
macro_rules! impl_element_scalar {
800829
(@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
801830
$(#[$meta])*
@@ -806,6 +835,7 @@ macro_rules! impl_element_scalar {
806835
PyArrayDescr::from_npy_type(py, $npy_type)
807836
}
808837
}
838+
impl_py_clone!($ty);
809839
};
810840
($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
811841
impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
@@ -842,6 +872,9 @@ unsafe impl Element for bf16 {
842872
}
843873
}
844874

875+
#[cfg(feature = "half")]
876+
impl_py_clone!(bf16);
877+
845878
impl_element_scalar!(Complex32 => NPY_CFLOAT,
846879
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
847880
impl_element_scalar!(Complex64 => NPY_CDOUBLE,

0 commit comments

Comments
 (0)