Skip to content

Commit eb99fe7

Browse files
committed
Add docs for inner, dot, einsum, and pyarray
1 parent dd22664 commit eb99fe7

File tree

3 files changed

+92
-24
lines changed

3 files changed

+92
-24
lines changed

src/array.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,8 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
710710

711711
impl<T: Copy + Element> PyArray<T, Ix0> {
712712
/// Get the element of zero-dimensional PyArray.
713+
///
714+
/// See [inner](../fn.inner.html) for example.
713715
pub fn item(&self) -> T {
714716
unsafe { *self.data() }
715717
}

src/lib.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ pub mod doc_test {
7575
doc_comment!(include_str!("../README.md"), readme);
7676
}
7777

78+
/// Create a [PyArray](./array/struct.PyArray.html) with one, two or three dimensions.
79+
/// This macro is backed by
80+
/// [`ndarray::array`](https://docs.rs/ndarray/latest/ndarray/macro.array.html).
81+
///
82+
/// # Example
83+
/// ```
84+
/// pyo3::Python::with_gil(|py| {
85+
/// let array = numpy::pyarray![py, [1, 2], [3, 4]];
86+
/// assert_eq!(
87+
/// array.readonly().as_array(),
88+
/// ndarray::array![[1, 2], [3, 4]]
89+
/// );
90+
/// });
7891
#[macro_export]
7992
macro_rules! pyarray {
8093
($py: ident, $([$([$($x:expr),* $(,)*]),+ $(,)*]),+ $(,)*) => {{

src/sum_products.rs

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1-
use crate::npyffi::{PyArrayObject, NPY_CASTING, NPY_ORDER};
1+
use crate::npyffi::{NPY_CASTING, NPY_ORDER};
22
use crate::{Element, PyArray, PY_ARRAY_API};
3-
use ndarray::Dimension;
3+
use ndarray::{Dimension, IxDyn};
44
use pyo3::{AsPyPointer, FromPyPointer, PyAny, PyNativeType, PyResult};
5+
use std::ffi::CStr;
56

67
/// Return the inner product of two arrays.
8+
///
9+
/// # Example
10+
/// ```
11+
/// pyo3::Python::with_gil(|py| {
12+
/// let array = numpy::pyarray![py, 1, 2, 3];
13+
/// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14+
/// assert_eq!(inner.item(), 14);
15+
/// });
16+
/// ```
717
pub fn inner<'py, T, DIN1, DIN2, DOUT>(
818
array1: &'py PyArray<T, DIN1>,
919
array2: &'py PyArray<T, DIN2>,
@@ -14,12 +24,27 @@ where
1424
DOUT: Dimension,
1525
T: Element,
1626
{
17-
let result = unsafe { PY_ARRAY_API.PyArray_InnerProduct(array1.as_ptr(), array2.as_ptr()) };
18-
let obj = unsafe { PyAny::from_owned_ptr_or_err(array1.py(), result)? };
27+
let obj = unsafe {
28+
let result = PY_ARRAY_API.PyArray_InnerProduct(array1.as_ptr(), array2.as_ptr());
29+
PyAny::from_owned_ptr_or_err(array1.py(), result)?
30+
};
1931
obj.extract()
2032
}
2133

2234
/// Return the dot product of two arrays.
35+
///
36+
/// # Example
37+
/// ```
38+
/// pyo3::Python::with_gil(|py| {
39+
/// let a = numpy::pyarray![py, [1, 0], [0, 1]];
40+
/// let b = numpy::pyarray![py, [4, 1], [2, 2]];
41+
/// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
42+
/// assert_eq!(
43+
/// dot.readonly().as_array(),
44+
/// ndarray::array![[4, 1], [2, 2]]
45+
/// );
46+
/// });
47+
/// ```
2348
pub fn dot<'py, T, DIN1, DIN2, DOUT>(
2449
array1: &'py PyArray<T, DIN1>,
2550
array2: &'py PyArray<T, DIN2>,
@@ -30,39 +55,67 @@ where
3055
DOUT: Dimension,
3156
T: Element,
3257
{
33-
let result = unsafe { PY_ARRAY_API.PyArray_MatrixProduct(array1.as_ptr(), array2.as_ptr()) };
34-
let obj = unsafe { PyAny::from_owned_ptr_or_err(array1.py(), result)? };
58+
let obj = unsafe {
59+
let result = PY_ARRAY_API.PyArray_MatrixProduct(array1.as_ptr(), array2.as_ptr());
60+
PyAny::from_owned_ptr_or_err(array1.py(), result)?
61+
};
3562
obj.extract()
3663
}
3764

38-
pub unsafe fn einsum_impl<'py, T, DIN, DOUT>(
39-
dummy_array: &'py PyArray<T, DIN>,
65+
/// Return the Einstein summation convention of given tensors.
66+
///
67+
/// We also provide the [einsum macro](./macro.einsum.html).
68+
pub fn einsum_impl<'py, T, DOUT>(
4069
subscripts: &str,
41-
arrays: &[*mut PyArrayObject],
70+
arrays: &[&'py PyArray<T, IxDyn>],
4271
) -> PyResult<&'py PyArray<T, DOUT>>
4372
where
44-
DIN: Dimension,
4573
DOUT: Dimension,
4674
T: Element,
4775
{
48-
let subscripts = std::ffi::CStr::from_bytes_with_nul(subscripts.as_bytes()).unwrap();
49-
let result = PY_ARRAY_API.PyArray_EinsteinSum(
50-
subscripts.as_ptr() as _,
51-
arrays.len() as _,
52-
arrays.as_ptr() as _,
53-
std::ptr::null_mut(),
54-
NPY_ORDER::NPY_KEEPORDER,
55-
NPY_CASTING::NPY_NO_CASTING,
56-
std::ptr::null_mut(),
57-
);
58-
let obj = PyAny::from_owned_ptr_or_err(dummy_array.py(), result)?;
76+
let subscripts: std::borrow::Cow<CStr> = if subscripts.ends_with("\0") {
77+
CStr::from_bytes_with_nul(subscripts.as_bytes())
78+
.unwrap()
79+
.into()
80+
} else {
81+
std::ffi::CString::new(subscripts).unwrap().into()
82+
};
83+
let obj = unsafe {
84+
let result = PY_ARRAY_API.PyArray_EinsteinSum(
85+
subscripts.as_ptr() as _,
86+
arrays.len() as _,
87+
arrays.as_ptr() as _,
88+
std::ptr::null_mut(),
89+
NPY_ORDER::NPY_KEEPORDER,
90+
NPY_CASTING::NPY_NO_CASTING,
91+
std::ptr::null_mut(),
92+
);
93+
PyAny::from_owned_ptr_or_err(arrays[0].py(), result)?
94+
};
5995
obj.extract()
6096
}
6197

98+
/// Return the Einstein summation convention of given tensors.
99+
///
100+
/// For more about the Einstein summation convention, you may reffer to
101+
/// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
102+
///
103+
/// # Example
104+
/// ```
105+
/// pyo3::Python::with_gil(|py| {
106+
/// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
107+
/// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
108+
/// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
109+
/// assert_eq!(
110+
/// einsum.readonly().as_array(),
111+
/// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
112+
/// );
113+
/// });
114+
/// ```
62115
#[macro_export]
63116
macro_rules! einsum {
64-
($subscripts: literal, $first_array: ident $(,$array: ident)* $(,)*) => {{
65-
let arrays = [$first_array.as_array_ptr(), $($array.as_array_ptr(),)*];
66-
unsafe { $crate::einsum_impl($first_array, concat!($subscripts, "\0"), &arrays) }
117+
($subscripts: literal $(,$array: ident)+ $(,)*) => {{
118+
let arrays = [$($array.to_dyn(),)+];
119+
unsafe { $crate::einsum_impl(concat!($subscripts, "\0"), &arrays) }
67120
}};
68121
}

0 commit comments

Comments
 (0)