Skip to content

Commit b8f051d

Browse files
vtavana and antonwolfy authored
update dpnp.inner implementation (#1726)
* update dpnp.inner
* address comments
* fix typo
---------
Co-authored-by: Anton <[email protected]>
1 parent 8fa2166 commit b8f051d

File tree

7 files changed

+232
-113
lines changed

7 files changed

+232
-113
lines changed

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ and the rest of the library
3636
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file
3737

3838
__all__ += [
39-
"dpnp_inner",
4039
"dpnp_kron",
4140
]
4241

@@ -48,80 +47,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_shapes_t)(c_dpctl.DPCTLSyclQue
4847
const c_dpctl.DPCTLEventVectorRef)
4948

5049

51-
cpdef utils.dpnp_descriptor dpnp_inner(dpnp_descriptor array1, dpnp_descriptor array2):
52-
result_type = numpy.promote_types(array1.dtype, array1.dtype)
53-
54-
assert(len(array1.shape) == len(array2.shape))
55-
56-
cdef shape_type_c array1_no_last_axes = array1.shape[:-1]
57-
cdef shape_type_c array2_no_last_axes = array2.shape[:-1]
58-
59-
cdef shape_type_c result_shape = array1_no_last_axes
60-
result_shape.insert(result_shape.end(), array2_no_last_axes.begin(), array2_no_last_axes.end())
61-
62-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(array1, array2)
63-
64-
# create result array with type given by FPTR data
65-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(result_shape,
66-
result_type,
67-
None,
68-
device=result_sycl_device,
69-
usm_type=result_usm_type,
70-
sycl_queue=result_sycl_queue)
71-
72-
# calculate input arrays offsets
73-
cdef shape_type_c array1_offsets = [1] * len(array1.shape)
74-
cdef shape_type_c array2_offsets = [1] * len(array2.shape)
75-
cdef size_t acc1 = 1
76-
cdef size_t acc2 = 1
77-
for axis in range(len(array1.shape) - 1, -1, -1):
78-
array1_offsets[axis] = acc1
79-
array2_offsets[axis] = acc2
80-
acc1 *= array1.shape[axis]
81-
acc2 *= array2.shape[axis]
82-
83-
cdef shape_type_c result_shape_offsets = [1] * len(result.shape)
84-
acc = 1
85-
for i in range(len(result.shape) - 1, -1, -1):
86-
result_shape_offsets[i] = acc
87-
acc *= result.shape[i]
88-
89-
cdef shape_type_c xyz
90-
cdef size_t array1_lin_index_base
91-
cdef size_t array2_lin_index_base
92-
cdef size_t axis2
93-
cdef long remainder
94-
cdef long quotient
95-
96-
result_flatiter = result.get_pyobj().flat
97-
array1_flatiter = array1.get_pyobj().flat
98-
array2_flatiter = array2.get_pyobj().flat
99-
100-
for idx1 in range(result.size):
101-
# reconstruct x,y,z from linear index
102-
xyz.clear()
103-
remainder = idx1
104-
for i in result_shape_offsets:
105-
quotient, remainder = divmod(remainder, i)
106-
xyz.push_back(quotient)
107-
108-
# calculate linear base input index
109-
array1_lin_index_base = 0
110-
array2_lin_index_base = 0
111-
for axis in range(len(array1_offsets) - 1):
112-
axis2 = axis + (len(xyz) / 2)
113-
array1_lin_index_base += array1_offsets[axis] * xyz[axis]
114-
array2_lin_index_base += array2_offsets[axis] * xyz[axis2]
115-
116-
# do inner product
117-
result_flatiter[idx1] = 0
118-
for idx2 in range(array1.shape[-1]):
119-
result_flatiter[idx1] += array1_flatiter[array1_lin_index_base + idx2] * \
120-
array2_flatiter[array2_lin_index_base + idx2]
121-
122-
return result
123-
124-
12550
cpdef utils.dpnp_descriptor dpnp_kron(dpnp_descriptor in_array1, dpnp_descriptor in_array2):
12651
cdef size_t ndim = max(in_array1.ndim, in_array2.ndim)
12752

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 86 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545

4646
# pylint: disable=no-name-in-module
4747
from .dpnp_algo import (
48-
dpnp_inner,
4948
dpnp_kron,
5049
)
5150
from .dpnp_utils import (
@@ -218,43 +217,92 @@ def einsum_path(*args, **kwargs):
218217
return call_origin(numpy.einsum_path, *args, **kwargs)
219218

220219

221-
def inner(x1, x2, **kwargs):
220+
def inner(a, b):
222221
"""
223222
Returns the inner product of two arrays.
224223
225224
For full documentation refer to :obj:`numpy.inner`.
226225
227-
Limitations
228-
-----------
229-
Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`.
230-
Keyword argument `kwargs` is currently unsupported.
231-
Otherwise the functions will be executed sequentially on CPU.
226+
Parameters
227+
----------
228+
a : {dpnp.ndarray, usm_ndarray, scalar}
229+
First input array. Both inputs `a` and `b` cannot be scalars
230+
at the same time.
231+
b : {dpnp.ndarray, usm_ndarray, scalar}
232+
Second input array. Both inputs `a` and `b` cannot be scalars
233+
at the same time.
234+
235+
Returns
236+
-------
237+
out : dpnp.ndarray
238+
If either `a` or `b` is a scalar, the shape of the returned array
239+
matches that of the array between `a` and `b`, whichever is an array.
240+
If `a` and `b` are both 1-D arrays then a 0-d array is returned;
241+
otherwise an array with a shape as
242+
``out.shape = (*a.shape[:-1], *b.shape[:-1])`` is returned.
243+
232244
233245
See Also
234246
--------
235-
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention
236-
on the operands.
237-
:obj:`dpnp.dot` : Returns the dot product of two arrays.
238-
:obj:`dpnp.tensordot` : Compute tensor dot product along specified axes.
239-
Input array data types are limited by supported DPNP :ref:`Data types`.
247+
:obj:`dpnp.einsum` : Einstein summation convention.
248+
:obj:`dpnp.dot` : Generalised matrix product,
249+
using second last dimension of `b`.
250+
:obj:`dpnp.tensordot` : Sum products over arbitrary axes.
240251
241252
Examples
242253
--------
254+
# Ordinary inner product for vectors
255+
243256
>>> import dpnp as np
244-
>>> a = np.array([1,2,3])
257+
>>> a = np.array([1, 2, 3])
245258
>>> b = np.array([0, 1, 0])
246-
>>> result = np.inner(a, b)
247-
>>> [x for x in result]
248-
[2]
259+
>>> np.inner(a, b)
260+
array(2)
261+
262+
# Some multidimensional examples
263+
264+
>>> a = np.arange(24).reshape((2,3,4))
265+
>>> b = np.arange(4)
266+
>>> c = np.inner(a, b)
267+
>>> c.shape
268+
(2, 3)
269+
>>> c
270+
array([[ 14, 38, 62],
271+
[86, 110, 134]])
272+
273+
>>> a = np.arange(2).reshape((1,1,2))
274+
>>> b = np.arange(6).reshape((3,2))
275+
>>> c = np.inner(a, b)
276+
>>> c.shape
277+
(1, 1, 3)
278+
>>> c
279+
array([[[1, 3, 5]]])
280+
281+
An example where `b` is a scalar
282+
283+
>>> np.inner(np.eye(2), 7)
284+
array([[7., 0.],
285+
[0., 7.]])
249286
250287
"""
251288

252-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
253-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
254-
if x1_desc and x2_desc and not kwargs:
255-
return dpnp_inner(x1_desc, x2_desc).get_pyobj()
289+
dpnp.check_supported_arrays_type(a, b, scalar_type=True)
290+
291+
if dpnp.isscalar(a) or dpnp.isscalar(b):
292+
return dpnp.multiply(a, b)
293+
294+
if a.ndim == 0 or b.ndim == 0:
295+
return dpnp.multiply(a, b)
296+
297+
if a.shape[-1] != b.shape[-1]:
298+
raise ValueError(
299+
"shape of input arrays is not similar at the last axis."
300+
)
301+
302+
if a.ndim == 1 and b.ndim == 1:
303+
return dpnp_dot(a, b)
256304

257-
return call_origin(numpy.inner, x1, x2, **kwargs)
305+
return dpnp.tensordot(a, b, axes=(-1, -1))
258306

259307

260308
def kron(x1, x2):
@@ -567,16 +615,20 @@ def tensordot(a, b, axes=2):
567615

568616
dpnp.check_supported_arrays_type(a, b, scalar_type=True)
569617

570-
if dpnp.isscalar(a):
571-
a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
572-
elif dpnp.isscalar(b):
573-
b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
618+
if dpnp.isscalar(a) or dpnp.isscalar(b):
619+
if not isinstance(axes, int) or axes != 0:
620+
raise ValueError(
621+
"One of the inputs is scalar, axes should be zero."
622+
)
623+
return dpnp.multiply(a, b)
574624

575625
try:
576626
iter(axes)
577627
except Exception as e: # pylint: disable=broad-exception-caught
578628
if not isinstance(axes, int):
579629
raise TypeError("Axes must be an integer.") from e
630+
if axes < 0:
631+
raise ValueError("Axes must be a nonnegative integer.") from e
580632
axes_a = tuple(range(-axes, 0))
581633
axes_b = tuple(range(0, axes))
582634
else:
@@ -590,6 +642,15 @@ def tensordot(a, b, axes=2):
590642
if len(axes_a) != len(axes_b):
591643
raise ValueError("Axes length mismatch.")
592644

645+
# Make the axes non-negative
646+
a_ndim = a.ndim
647+
b_ndim = b.ndim
648+
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
649+
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")
650+
651+
if a.ndim == 0 or b.ndim == 0:
652+
return dpnp.multiply(a, b)
653+
593654
a_shape = a.shape
594655
b_shape = b.shape
595656
for axis_a, axis_b in zip(axes_a, axes_b):
@@ -598,12 +659,6 @@ def tensordot(a, b, axes=2):
598659
"shape of input arrays is not similar at requested axes."
599660
)
600661

601-
# Make the axes non-negative
602-
a_ndim = a.ndim
603-
b_ndim = b.ndim
604-
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
605-
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")
606-
607662
# Move the axes to sum over, to the end of "a"
608663
notin = tuple(k for k in range(a_ndim) if k not in axes_a)
609664
newaxes_a = notin + axes_a

dpnp/dpnp_iface_mathematical.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,10 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
770770
raise TypeError(f"axis should be an integer but got, {type(axis)}.")
771771
axisa, axisb, axisc = (axis,) * 3
772772
dpnp.check_supported_arrays_type(a, b)
773+
if a.dtype == dpnp.bool and b.dtype == dpnp.bool:
774+
raise TypeError(
775+
"Input arrays with boolean data type are not supported."
776+
)
773777
# Check axisa and axisb are within bounds
774778
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix="axisa")
775779
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix="axisb")

0 commit comments

Comments
 (0)