Skip to content

Commit 8c59f1f

Browse files
committed
inner, dot, einsum
1 parent fdf7c0a commit 8c59f1f

File tree

4 files changed

+149
-2
lines changed

4 files changed

+149
-2
lines changed

src/lib.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub mod npyffi;
4343
pub mod npyiter;
4444
mod readonly;
4545
mod slice_box;
46+
mod sum_products;
4647

4748
pub use crate::array::{
4849
get_array_module, PyArray, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6,
@@ -59,7 +60,8 @@ pub use crate::readonly::{
5960
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6061
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6162
};
62-
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
63+
pub use crate::sum_products::{dot, einsum_impl, inner};
64+
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6365

6466
/// Test readme
6567
#[doc(hidden)]
@@ -72,3 +74,16 @@ pub mod doc_test {
7274
}
7375
doc_comment!(include_str!("../README.md"), readme);
7476
}
77+
78+
#[macro_export]
79+
macro_rules! pyarray {
80+
($py: ident, $([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{
81+
$crate::IntoPyArray::into_pyarray($crate::array![$([$([$($x,)*],)*],)*], $py)
82+
}};
83+
($py: ident, $([$($x:expr),* $(,)*]),+ $(,)*) => {{
84+
$crate::IntoPyArray::into_pyarray($crate::array![$([$($x,)*],)*], $py)
85+
}};
86+
($py: ident, $($x:expr),* $(,)*) => {{
87+
$crate::IntoPyArray::into_pyarray($crate::array![$($x,)*], $py)
88+
}};
89+
}

src/npyffi/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ impl PyArrayAPI {
286286
impl_api![273; PyArray_ResultType(narrs: npy_intp, arr: *mut *mut PyArrayObject, ndtypes: npy_intp, dtypes: *mut *mut PyArray_Descr) -> *mut PyArray_Descr];
287287
impl_api![274; PyArray_CanCastArrayTo(arr: *mut PyArrayObject, to: *mut PyArray_Descr, casting: NPY_CASTING) -> npy_bool];
288288
impl_api![275; PyArray_CanCastTypeTo(from: *mut PyArray_Descr, to: *mut PyArray_Descr, casting: NPY_CASTING) -> npy_bool];
289-
impl_api![276; PyArray_EinsteinSum(subscripts: *mut c_char, nop: npy_intp, op_in: *mut *mut PyArrayObject, dtype: *mut PyArray_Descr, order: NPY_ORDER, casting: NPY_CASTING, out: *mut PyArrayObject) -> *mut PyArrayObject];
289+
impl_api![276; PyArray_EinsteinSum(subscripts: *mut c_char, nop: npy_intp, op_in: *mut *mut PyArrayObject, dtype: *mut PyArray_Descr, order: NPY_ORDER, casting: NPY_CASTING, out: *mut PyArrayObject) -> *mut PyObject];
290290
impl_api![277; PyArray_NewLikeArray(prototype: *mut PyArrayObject, order: NPY_ORDER, dtype: *mut PyArray_Descr, subok: c_int) -> *mut PyObject];
291291
impl_api![278; PyArray_GetArrayParamsFromObject(op: *mut PyObject, requested_dtype: *mut PyArray_Descr, writeable: npy_bool, out_dtype: *mut *mut PyArray_Descr, out_ndim: *mut c_int, out_dims: *mut npy_intp, out_arr: *mut *mut PyArrayObject, context: *mut PyObject) -> c_int];
292292
impl_api![279; PyArray_ConvertClipmodeSequence(object: *mut PyObject, modes: *mut NPY_CLIPMODE, n: c_int) -> c_int];

src/sum_products.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use crate::npyffi::{PyArrayObject, NPY_CASTING, NPY_ORDER};
2+
use crate::{Element, PyArray, PY_ARRAY_API};
3+
use ndarray::Dimension;
4+
use pyo3::{AsPyPointer, FromPyPointer, PyAny, PyNativeType, PyResult};
5+
6+
/// Return the inner product of two arrays.
7+
pub fn inner<'py, T, DIN1, DIN2, DOUT>(
8+
array1: &'py PyArray<T, DIN1>,
9+
array2: &'py PyArray<T, DIN2>,
10+
) -> PyResult<&'py PyArray<T, DOUT>>
11+
where
12+
DIN1: Dimension,
13+
DIN2: Dimension,
14+
DOUT: Dimension,
15+
T: Element,
16+
{
17+
let result = unsafe { PY_ARRAY_API.PyArray_InnerProduct(array1.as_ptr(), array2.as_ptr()) };
18+
let obj = unsafe { PyAny::from_owned_ptr_or_err(array1.py(), result)? };
19+
obj.extract()
20+
}
21+
22+
/// Return the dot product of two arrays.
23+
pub fn dot<'py, T, DIN1, DIN2, DOUT>(
24+
array1: &'py PyArray<T, DIN1>,
25+
array2: &'py PyArray<T, DIN2>,
26+
) -> PyResult<&'py PyArray<T, DOUT>>
27+
where
28+
DIN1: Dimension,
29+
DIN2: Dimension,
30+
DOUT: Dimension,
31+
T: Element,
32+
{
33+
let result = unsafe { PY_ARRAY_API.PyArray_MatrixProduct(array1.as_ptr(), array2.as_ptr()) };
34+
let obj = unsafe { PyAny::from_owned_ptr_or_err(array1.py(), result)? };
35+
obj.extract()
36+
}
37+
38+
pub unsafe fn einsum_impl<'py, T, DIN, DOUT>(
39+
dummy_array: &'py PyArray<T, DIN>,
40+
subscripts: &str,
41+
arrays: &[*mut PyArrayObject],
42+
) -> PyResult<&'py PyArray<T, DOUT>>
43+
where
44+
DIN: Dimension,
45+
DOUT: Dimension,
46+
T: Element,
47+
{
48+
let subscripts = std::ffi::CStr::from_bytes_with_nul(subscripts.as_bytes()).unwrap();
49+
let result = PY_ARRAY_API.PyArray_EinsteinSum(
50+
subscripts.as_ptr() as _,
51+
arrays.len() as _,
52+
arrays.as_ptr() as _,
53+
std::ptr::null_mut(),
54+
NPY_ORDER::NPY_KEEPORDER,
55+
NPY_CASTING::NPY_NO_CASTING,
56+
std::ptr::null_mut(),
57+
);
58+
let obj = PyAny::from_owned_ptr_or_err(dummy_array.py(), result)?;
59+
obj.extract()
60+
}
61+
62+
#[macro_export]
63+
macro_rules! einsum {
64+
($subscripts: literal, $first_array: ident $(,$array: ident)* $(,)*) => {{
65+
let arrays = [$first_array.as_array_ptr(), $($array.as_array_ptr(),)*];
66+
unsafe { $crate::einsum_impl($first_array, concat!($subscripts, "\0"), &arrays) }
67+
}};
68+
}

tests/sum_products.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use numpy::{array, dot, einsum, inner, pyarray, PyArray1, PyArray2};
2+
3+
#[test]
4+
fn test_dot() {
5+
pyo3::Python::with_gil(|py| {
6+
let a = pyarray![py, [1, 0], [0, 1]];
7+
let b = pyarray![py, [4, 1], [2, 2]];
8+
let c = dot(a, b).unwrap();
9+
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);
10+
let a = pyarray![py, 1, 2, 3];
11+
let err: pyo3::PyResult<&PyArray2<_>> = dot(a, b);
12+
let err = err.unwrap_err();
13+
assert!(err.to_string().contains("not aligned"), "{}", err);
14+
})
15+
}
16+
17+
#[test]
18+
fn test_inner() {
19+
pyo3::Python::with_gil(|py| {
20+
let a = pyarray![py, 1, 2, 3];
21+
let b = pyarray![py, 0, 1, 0];
22+
let c = inner(a, b).unwrap();
23+
assert_eq!(c.readonly().as_array(), ndarray::arr0(2));
24+
let a = pyarray![py, [1, 0], [0, 1]];
25+
let b = pyarray![py, [4, 1], [2, 2]];
26+
let c = inner(a, b).unwrap();
27+
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);
28+
let a = pyarray![py, 1, 2, 3];
29+
let err: pyo3::PyResult<&PyArray2<_>> = inner(a, b);
30+
let err = err.unwrap_err();
31+
assert!(err.to_string().contains("not aligned"), "{}", err);
32+
})
33+
}
34+
35+
#[test]
36+
fn test_einsum() {
37+
pyo3::Python::with_gil(|py| {
38+
let a = PyArray1::<i32>::arange(py, 0, 25, 1)
39+
.reshape([5, 5])
40+
.unwrap();
41+
let b = pyarray![py, 0, 1, 2, 3, 4];
42+
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
43+
assert_eq!(
44+
einsum!("ii", a).unwrap().readonly().as_array(),
45+
ndarray::arr0(60)
46+
);
47+
assert_eq!(
48+
einsum!("ii->i", a).unwrap().readonly().as_array(),
49+
array![0, 6, 12, 18, 24],
50+
);
51+
assert_eq!(
52+
einsum!("ij->i", a).unwrap().readonly().as_array(),
53+
array![10, 35, 60, 85, 110],
54+
);
55+
assert_eq!(
56+
einsum!("ji", c).unwrap().readonly().as_array(),
57+
array![[0, 3], [1, 4], [2, 5]],
58+
);
59+
assert_eq!(
60+
einsum!("ij,j", a, b).unwrap().readonly().as_array(),
61+
array![30, 80, 130, 180, 230],
62+
);
63+
})
64+
}

0 commit comments

Comments
 (0)