1- use crate :: npyffi:: { PyArrayObject , NPY_CASTING , NPY_ORDER } ;
1+ use crate :: npyffi:: { NPY_CASTING , NPY_ORDER } ;
22use crate :: { Element , PyArray , PY_ARRAY_API } ;
3- use ndarray:: Dimension ;
3+ use ndarray:: { Dimension , IxDyn } ;
44use pyo3:: { AsPyPointer , FromPyPointer , PyAny , PyNativeType , PyResult } ;
5+ use std:: ffi:: CStr ;
56
67/// Return the inner product of two arrays.
8+ ///
9+ /// # Example
10+ /// ```
11+ /// pyo3::Python::with_gil(|py| {
12+ /// let array = numpy::pyarray![py, 1, 2, 3];
13+ /// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14+ /// assert_eq!(inner.item(), 14);
15+ /// });
16+ /// ```
717pub fn inner < ' py , T , DIN1 , DIN2 , DOUT > (
818 array1 : & ' py PyArray < T , DIN1 > ,
919 array2 : & ' py PyArray < T , DIN2 > ,
@@ -14,12 +24,27 @@ where
1424 DOUT : Dimension ,
1525 T : Element ,
1626{
17- let result = unsafe { PY_ARRAY_API . PyArray_InnerProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) } ;
18- let obj = unsafe { PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ? } ;
27+ let obj = unsafe {
28+ let result = PY_ARRAY_API . PyArray_InnerProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) ;
29+ PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ?
30+ } ;
1931 obj. extract ( )
2032}
2133
2234/// Return the dot product of two arrays.
35+ ///
36+ /// # Example
37+ /// ```
38+ /// pyo3::Python::with_gil(|py| {
39+ /// let a = numpy::pyarray![py, [1, 0], [0, 1]];
40+ /// let b = numpy::pyarray![py, [4, 1], [2, 2]];
41+ /// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
42+ /// assert_eq!(
43+ /// dot.readonly().as_array(),
44+ /// ndarray::array![[4, 1], [2, 2]]
45+ /// );
46+ /// });
47+ /// ```
2348pub fn dot < ' py , T , DIN1 , DIN2 , DOUT > (
2449 array1 : & ' py PyArray < T , DIN1 > ,
2550 array2 : & ' py PyArray < T , DIN2 > ,
@@ -30,39 +55,67 @@ where
3055 DOUT : Dimension ,
3156 T : Element ,
3257{
33- let result = unsafe { PY_ARRAY_API . PyArray_MatrixProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) } ;
34- let obj = unsafe { PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ? } ;
58+ let obj = unsafe {
59+ let result = PY_ARRAY_API . PyArray_MatrixProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) ;
60+ PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ?
61+ } ;
3562 obj. extract ( )
3663}
3764
38- pub unsafe fn einsum_impl < ' py , T , DIN , DOUT > (
39- dummy_array : & ' py PyArray < T , DIN > ,
65+ /// Return the Einstein summation convention of given tensors.
66+ ///
67+ /// We also provide the [einsum macro](./macro.einsum.html).
68+ pub fn einsum_impl < ' py , T , DOUT > (
4069 subscripts : & str ,
41- arrays : & [ * mut PyArrayObject ] ,
70+ arrays : & [ & ' py PyArray < T , IxDyn > ] ,
4271) -> PyResult < & ' py PyArray < T , DOUT > >
4372where
44- DIN : Dimension ,
4573 DOUT : Dimension ,
4674 T : Element ,
4775{
48- let subscripts = std:: ffi:: CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) ) . unwrap ( ) ;
49- let result = PY_ARRAY_API . PyArray_EinsteinSum (
50- subscripts. as_ptr ( ) as _ ,
51- arrays. len ( ) as _ ,
52- arrays. as_ptr ( ) as _ ,
53- std:: ptr:: null_mut ( ) ,
54- NPY_ORDER :: NPY_KEEPORDER ,
55- NPY_CASTING :: NPY_NO_CASTING ,
56- std:: ptr:: null_mut ( ) ,
57- ) ;
58- let obj = PyAny :: from_owned_ptr_or_err ( dummy_array. py ( ) , result) ?;
76+ let subscripts: std:: borrow:: Cow < CStr > = if subscripts. ends_with ( "\0 " ) {
77+ CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) )
78+ . unwrap ( )
79+ . into ( )
80+ } else {
81+ std:: ffi:: CString :: new ( subscripts) . unwrap ( ) . into ( )
82+ } ;
83+ let obj = unsafe {
84+ let result = PY_ARRAY_API . PyArray_EinsteinSum (
85+ subscripts. as_ptr ( ) as _ ,
86+ arrays. len ( ) as _ ,
87+ arrays. as_ptr ( ) as _ ,
88+ std:: ptr:: null_mut ( ) ,
89+ NPY_ORDER :: NPY_KEEPORDER ,
90+ NPY_CASTING :: NPY_NO_CASTING ,
91+ std:: ptr:: null_mut ( ) ,
92+ ) ;
93+ PyAny :: from_owned_ptr_or_err ( arrays[ 0 ] . py ( ) , result) ?
94+ } ;
5995 obj. extract ( )
6096}
6197
98+ /// Return the Einstein summation convention of given tensors.
99+ ///
100+ /// For more about the Einstein summation convention, you may reffer to
101+ /// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
102+ ///
103+ /// # Example
104+ /// ```
105+ /// pyo3::Python::with_gil(|py| {
106+ /// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
107+ /// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
108+ /// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
109+ /// assert_eq!(
110+ /// einsum.readonly().as_array(),
111+ /// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
112+ /// );
113+ /// });
114+ /// ```
62115#[ macro_export]
63116macro_rules! einsum {
64- ( $subscripts: literal, $first_array : ident $ ( , $array: ident) * $( , ) * ) => { {
65- let arrays = [ $first_array . as_array_ptr ( ) , $ ( $ array. as_array_ptr ( ) , ) * ] ;
66- unsafe { $crate:: einsum_impl( $first_array , concat!( $subscripts, "\0 " ) , & arrays) }
117+ ( $subscripts: literal $ ( , $array: ident) + $( , ) * ) => { {
118+ let arrays = [ $( $ array. to_dyn ( ) , ) + ] ;
119+ unsafe { $crate:: einsum_impl( concat!( $subscripts, "\0 " ) , & arrays) }
67120 } } ;
68121}
0 commit comments