1
- use crate :: npyffi:: { PyArrayObject , NPY_CASTING , NPY_ORDER } ;
1
+ use crate :: npyffi:: { NPY_CASTING , NPY_ORDER } ;
2
2
use crate :: { Element , PyArray , PY_ARRAY_API } ;
3
- use ndarray:: Dimension ;
3
+ use ndarray:: { Dimension , IxDyn } ;
4
4
use pyo3:: { AsPyPointer , FromPyPointer , PyAny , PyNativeType , PyResult } ;
5
+ use std:: ffi:: CStr ;
5
6
6
7
/// Return the inner product of two arrays.
8
+ ///
9
+ /// # Example
10
+ /// ```
11
+ /// pyo3::Python::with_gil(|py| {
12
+ /// let array = numpy::pyarray![py, 1, 2, 3];
13
+ /// let inner: &numpy::PyArray0::<_> = numpy::inner(array, array).unwrap();
14
+ /// assert_eq!(inner.item(), 14);
15
+ /// });
16
+ /// ```
7
17
pub fn inner < ' py , T , DIN1 , DIN2 , DOUT > (
8
18
array1 : & ' py PyArray < T , DIN1 > ,
9
19
array2 : & ' py PyArray < T , DIN2 > ,
@@ -14,12 +24,27 @@ where
14
24
DOUT : Dimension ,
15
25
T : Element ,
16
26
{
17
- let result = unsafe { PY_ARRAY_API . PyArray_InnerProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) } ;
18
- let obj = unsafe { PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ? } ;
27
+ let obj = unsafe {
28
+ let result = PY_ARRAY_API . PyArray_InnerProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) ;
29
+ PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ?
30
+ } ;
19
31
obj. extract ( )
20
32
}
21
33
22
34
/// Return the dot product of two arrays.
35
+ ///
36
+ /// # Example
37
+ /// ```
38
+ /// pyo3::Python::with_gil(|py| {
39
+ /// let a = numpy::pyarray![py, [1, 0], [0, 1]];
40
+ /// let b = numpy::pyarray![py, [4, 1], [2, 2]];
41
+ /// let dot: &numpy::PyArray2::<_> = numpy::dot(a, b).unwrap();
42
+ /// assert_eq!(
43
+ /// dot.readonly().as_array(),
44
+ /// ndarray::array![[4, 1], [2, 2]]
45
+ /// );
46
+ /// });
47
+ /// ```
23
48
pub fn dot < ' py , T , DIN1 , DIN2 , DOUT > (
24
49
array1 : & ' py PyArray < T , DIN1 > ,
25
50
array2 : & ' py PyArray < T , DIN2 > ,
@@ -30,39 +55,67 @@ where
30
55
DOUT : Dimension ,
31
56
T : Element ,
32
57
{
33
- let result = unsafe { PY_ARRAY_API . PyArray_MatrixProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) } ;
34
- let obj = unsafe { PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ? } ;
58
+ let obj = unsafe {
59
+ let result = PY_ARRAY_API . PyArray_MatrixProduct ( array1. as_ptr ( ) , array2. as_ptr ( ) ) ;
60
+ PyAny :: from_owned_ptr_or_err ( array1. py ( ) , result) ?
61
+ } ;
35
62
obj. extract ( )
36
63
}
37
64
38
- pub unsafe fn einsum_impl < ' py , T , DIN , DOUT > (
39
- dummy_array : & ' py PyArray < T , DIN > ,
65
+ /// Return the Einstein summation convention of given tensors.
66
+ ///
67
+ /// We also provide the [einsum macro](./macro.einsum.html).
68
+ pub fn einsum_impl < ' py , T , DOUT > (
40
69
subscripts : & str ,
41
- arrays : & [ * mut PyArrayObject ] ,
70
+ arrays : & [ & ' py PyArray < T , IxDyn > ] ,
42
71
) -> PyResult < & ' py PyArray < T , DOUT > >
43
72
where
44
- DIN : Dimension ,
45
73
DOUT : Dimension ,
46
74
T : Element ,
47
75
{
48
- let subscripts = std:: ffi:: CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) ) . unwrap ( ) ;
49
- let result = PY_ARRAY_API . PyArray_EinsteinSum (
50
- subscripts. as_ptr ( ) as _ ,
51
- arrays. len ( ) as _ ,
52
- arrays. as_ptr ( ) as _ ,
53
- std:: ptr:: null_mut ( ) ,
54
- NPY_ORDER :: NPY_KEEPORDER ,
55
- NPY_CASTING :: NPY_NO_CASTING ,
56
- std:: ptr:: null_mut ( ) ,
57
- ) ;
58
- let obj = PyAny :: from_owned_ptr_or_err ( dummy_array. py ( ) , result) ?;
76
+ let subscripts: std:: borrow:: Cow < CStr > = if subscripts. ends_with ( "\0 " ) {
77
+ CStr :: from_bytes_with_nul ( subscripts. as_bytes ( ) )
78
+ . unwrap ( )
79
+ . into ( )
80
+ } else {
81
+ std:: ffi:: CString :: new ( subscripts) . unwrap ( ) . into ( )
82
+ } ;
83
+ let obj = unsafe {
84
+ let result = PY_ARRAY_API . PyArray_EinsteinSum (
85
+ subscripts. as_ptr ( ) as _ ,
86
+ arrays. len ( ) as _ ,
87
+ arrays. as_ptr ( ) as _ ,
88
+ std:: ptr:: null_mut ( ) ,
89
+ NPY_ORDER :: NPY_KEEPORDER ,
90
+ NPY_CASTING :: NPY_NO_CASTING ,
91
+ std:: ptr:: null_mut ( ) ,
92
+ ) ;
93
+ PyAny :: from_owned_ptr_or_err ( arrays[ 0 ] . py ( ) , result) ?
94
+ } ;
59
95
obj. extract ( )
60
96
}
61
97
98
+ /// Return the Einstein summation convention of given tensors.
99
+ ///
100
+ /// For more about the Einstein summation convention, you may reffer to
101
+ /// [the numpy document](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
102
+ ///
103
+ /// # Example
104
+ /// ```
105
+ /// pyo3::Python::with_gil(|py| {
106
+ /// let a = numpy::PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
107
+ /// let b = numpy::pyarray![py, [20, 30], [40, 50], [60, 70]];
108
+ /// let einsum = numpy::einsum!("ijk,ji->ik", a, b).unwrap();
109
+ /// assert_eq!(
110
+ /// einsum.readonly().as_array(),
111
+ /// ndarray::array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
112
+ /// );
113
+ /// });
114
+ /// ```
62
115
#[ macro_export]
63
116
macro_rules! einsum {
64
- ( $subscripts: literal, $first_array : ident $ ( , $array: ident) * $( , ) * ) => { {
65
- let arrays = [ $first_array . as_array_ptr ( ) , $ ( $ array. as_array_ptr ( ) , ) * ] ;
66
- unsafe { $crate:: einsum_impl( $first_array , concat!( $subscripts, "\0 " ) , & arrays) }
117
+ ( $subscripts: literal $ ( , $array: ident) + $( , ) * ) => { {
118
+ let arrays = [ $( $ array. to_dyn ( ) , ) + ] ;
119
+ unsafe { $crate:: einsum_impl( concat!( $subscripts, "\0 " ) , & arrays) }
67
120
} } ;
68
121
}
0 commit comments