45
45
46
46
# pylint: disable=no-name-in-module
47
47
from .dpnp_algo import (
48
- dpnp_inner ,
49
48
dpnp_kron ,
50
49
)
51
50
from .dpnp_utils import (
@@ -218,43 +217,92 @@ def einsum_path(*args, **kwargs):
218
217
return call_origin (numpy .einsum_path , * args , ** kwargs )
219
218
220
219
221
- def inner (x1 , x2 , ** kwargs ):
220
+ def inner (a , b ):
222
221
"""
223
222
Returns the inner product of two arrays.
224
223
225
224
For full documentation refer to :obj:`numpy.inner`.
226
225
227
- Limitations
228
- -----------
229
- Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`.
230
- Keyword argument `kwargs` is currently unsupported.
231
- Otherwise the functions will be executed sequentially on CPU.
226
+ Parameters
227
+ ----------
228
+ a : {dpnp.ndarray, usm_ndarray, scalar}
229
+ First input array. Both inputs `a` and `b` can not be scalars
230
+ at the same time.
231
+ b : {dpnp.ndarray, usm_ndarray, scalar}
232
+ Second input array. Both inputs `a` and `b` can not be scalars
233
+ at the same time.
234
+
235
+ Returns
236
+ -------
237
+ out : dpnp.ndarray
238
+ If either `a` or `b` is a scalar, the shape of the returned arrays
239
+ matches that of the array between `a` and `b`, whichever is an array.
240
+ If `a` and `b` are both 1-D arrays then a 0-d array is returned;
241
+ otherwise an array with a shape as
242
+ ``out.shape = (*a.shape[:-1], *b.shape[:-1])`` is returned.
243
+
232
244
233
245
See Also
234
246
--------
235
- :obj:`dpnp.einsum` : Evaluates the Einstein summation convention
236
- on the operands.
237
- :obj:`dpnp.dot` : Returns the dot product of two arrays.
238
- :obj:`dpnp.tensordot` : Compute tensor dot product along specified axes.
239
- Input array data types are limited by supported DPNP :ref:`Data types`.
247
+ :obj:`dpnp.einsum` : Einstein summation convention..
248
+ :obj:`dpnp.dot` : Generalised matrix product,
249
+ using second last dimension of `b`.
250
+ :obj:`dpnp.tensordot` : Sum products over arbitrary axes.
240
251
241
252
Examples
242
253
--------
254
+ # Ordinary inner product for vectors
255
+
243
256
>>> import dpnp as np
244
- >>> a = np.array([1,2, 3])
257
+ >>> a = np.array([1, 2, 3])
245
258
>>> b = np.array([0, 1, 0])
246
- >>> result = np.inner(a, b)
247
- >>> [x for x in result]
248
- [2]
259
+ >>> np.inner(a, b)
260
+ array(2)
261
+
262
+ # Some multidimensional examples
263
+
264
+ >>> a = np.arange(24).reshape((2,3,4))
265
+ >>> b = np.arange(4)
266
+ >>> c = np.inner(a, b)
267
+ >>> c.shape
268
+ (2, 3)
269
+ >>> c
270
+ array([[ 14, 38, 62],
271
+ [86, 110, 134]])
272
+
273
+ >>> a = np.arange(2).reshape((1,1,2))
274
+ >>> b = np.arange(6).reshape((3,2))
275
+ >>> c = np.inner(a, b)
276
+ >>> c.shape
277
+ (1, 1, 3)
278
+ >>> c
279
+ array([[[1, 3, 5]]])
280
+
281
+ An example where `b` is a scalar
282
+
283
+ >>> np.inner(np.eye(2), 7)
284
+ array([[7., 0.],
285
+ [0., 7.]])
249
286
250
287
"""
251
288
252
- x1_desc = dpnp .get_dpnp_descriptor (x1 , copy_when_nondefault_queue = False )
253
- x2_desc = dpnp .get_dpnp_descriptor (x2 , copy_when_nondefault_queue = False )
254
- if x1_desc and x2_desc and not kwargs :
255
- return dpnp_inner (x1_desc , x2_desc ).get_pyobj ()
289
+ dpnp .check_supported_arrays_type (a , b , scalar_type = True )
290
+
291
+ if dpnp .isscalar (a ) or dpnp .isscalar (b ):
292
+ return dpnp .multiply (a , b )
293
+
294
+ if a .ndim == 0 or b .ndim == 0 :
295
+ return dpnp .multiply (a , b )
296
+
297
+ if a .shape [- 1 ] != b .shape [- 1 ]:
298
+ raise ValueError (
299
+ "shape of input arrays is not similar at the last axis."
300
+ )
301
+
302
+ if a .ndim == 1 and b .ndim == 1 :
303
+ return dpnp_dot (a , b )
256
304
257
- return call_origin ( numpy . inner , x1 , x2 , ** kwargs )
305
+ return dpnp . tensordot ( a , b , axes = ( - 1 , - 1 ) )
258
306
259
307
260
308
def kron (x1 , x2 ):
@@ -567,16 +615,20 @@ def tensordot(a, b, axes=2):
567
615
568
616
dpnp .check_supported_arrays_type (a , b , scalar_type = True )
569
617
570
- if dpnp .isscalar (a ):
571
- a = dpnp .array (a , sycl_queue = b .sycl_queue , usm_type = b .usm_type )
572
- elif dpnp .isscalar (b ):
573
- b = dpnp .array (b , sycl_queue = a .sycl_queue , usm_type = a .usm_type )
618
+ if dpnp .isscalar (a ) or dpnp .isscalar (b ):
619
+ if not isinstance (axes , int ) or axes != 0 :
620
+ raise ValueError (
621
+ "One of the inputs is scalar, axes should be zero."
622
+ )
623
+ return dpnp .multiply (a , b )
574
624
575
625
try :
576
626
iter (axes )
577
627
except Exception as e : # pylint: disable=broad-exception-caught
578
628
if not isinstance (axes , int ):
579
629
raise TypeError ("Axes must be an integer." ) from e
630
+ if axes < 0 :
631
+ raise ValueError ("Axes must be a nonnegative integer." ) from e
580
632
axes_a = tuple (range (- axes , 0 ))
581
633
axes_b = tuple (range (0 , axes ))
582
634
else :
@@ -590,6 +642,15 @@ def tensordot(a, b, axes=2):
590
642
if len (axes_a ) != len (axes_b ):
591
643
raise ValueError ("Axes length mismatch." )
592
644
645
+ # Make the axes non-negative
646
+ a_ndim = a .ndim
647
+ b_ndim = b .ndim
648
+ axes_a = normalize_axis_tuple (axes_a , a_ndim , "axis_a" )
649
+ axes_b = normalize_axis_tuple (axes_b , b_ndim , "axis_b" )
650
+
651
+ if a .ndim == 0 or b .ndim == 0 :
652
+ return dpnp .multiply (a , b )
653
+
593
654
a_shape = a .shape
594
655
b_shape = b .shape
595
656
for axis_a , axis_b in zip (axes_a , axes_b ):
@@ -598,12 +659,6 @@ def tensordot(a, b, axes=2):
598
659
"shape of input arrays is not similar at requested axes."
599
660
)
600
661
601
- # Make the axes non-negative
602
- a_ndim = a .ndim
603
- b_ndim = b .ndim
604
- axes_a = normalize_axis_tuple (axes_a , a_ndim , "axis" )
605
- axes_b = normalize_axis_tuple (axes_b , b_ndim , "axis" )
606
-
607
662
# Move the axes to sum over, to the end of "a"
608
663
notin = tuple (k for k in range (a_ndim ) if k not in axes_a )
609
664
newaxes_a = notin + axes_a
0 commit comments