Skip to content

Commit d7400a2

Browse files
npolina4antonwolfy
andauthored
Implement outer for element-wise functions (#1813)
* Add ufunc outer * Update dpnp.outer * Changed flatten to ravel in outer function to avoid unnecessary copy * Update flatten implementation * Update docs for outer * address comments * Update dpnp/dpnp_algo/dpnp_elementwise_common.py --------- Co-authored-by: Anton <[email protected]>
1 parent 31f9405 commit d7400a2

File tree

5 files changed

+175
-47
lines changed

5 files changed

+175
-47
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,107 @@ def __call__(
341341
return out
342342
return dpnp_array._create_from_usm_ndarray(res_usm)
343343

344+
def outer(
345+
self,
346+
x1,
347+
x2,
348+
out=None,
349+
where=True,
350+
order="K",
351+
dtype=None,
352+
subok=True,
353+
**kwargs,
354+
):
355+
"""
356+
Apply the ufunc op to all pairs (a, b) with a in A and b in B.
357+
358+
Parameters
359+
----------
360+
x1 : {dpnp.ndarray, usm_ndarray}
361+
First input array.
362+
x2 : {dpnp.ndarray, usm_ndarray}
363+
Second input array.
364+
out : {None, dpnp.ndarray, usm_ndarray}, optional
365+
Output array to populate.
366+
Array must have the correct shape and the expected data type.
367+
order : {None, "C", "F", "A", "K"}, optional
368+
Memory layout of the newly output array, Cannot be provided
369+
together with `out`. Default: ``"K"``.
370+
dtype : {None, dtype}, optional
371+
If provided, the destination array will have this dtype. Cannot be
372+
provided together with `out`. Default: ``None``.
373+
374+
Returns
375+
-------
376+
out : dpnp.ndarray
377+
Output array. The data type of the returned array is determined by
378+
the Type Promotion Rules.
379+
380+
Limitations
381+
-----------
382+
Parameters `where` and `subok` are supported with their default values.
383+
Keyword argument `kwargs` is currently unsupported.
384+
Otherwise ``NotImplementedError`` exception will be raised.
385+
386+
See also
387+
--------
388+
:obj:`dpnp.outer` : A less powerful version of dpnp.multiply.outer
389+
that ravels all inputs to 1D. This exists primarily
390+
for compatibility with old code.
391+
392+
:obj:`dpnp.tensordot` : dpnp.tensordot(a, b, axes=((), ())) and
393+
dpnp.multiply.outer(a, b) behave same for all
394+
dimensions of a and b.
395+
396+
Examples
397+
--------
398+
>>> import dpnp as np
399+
>>> A = np.array([1, 2, 3])
400+
>>> B = np.array([4, 5, 6])
401+
>>> np.multiply.outer(A, B)
402+
array([[ 4, 5, 6],
403+
[ 8, 10, 12],
404+
[12, 15, 18]])
405+
406+
A multi-dimensional example:
407+
>>> A = np.array([[1, 2, 3], [4, 5, 6]])
408+
>>> A.shape
409+
(2, 3)
410+
>>> B = np.array([[1, 2, 3, 4]])
411+
>>> B.shape
412+
(1, 4)
413+
>>> C = np.multiply.outer(A, B)
414+
>>> C.shape; C
415+
(2, 3, 1, 4)
416+
array([[[[ 1, 2, 3, 4]],
417+
[[ 2, 4, 6, 8]],
418+
[[ 3, 6, 9, 12]]],
419+
[[[ 4, 8, 12, 16]],
420+
[[ 5, 10, 15, 20]],
421+
[[ 6, 12, 18, 24]]]])
422+
423+
"""
424+
425+
dpnp.check_supported_arrays_type(
426+
x1, x2, scalar_type=True, all_scalars=False
427+
)
428+
if dpnp.isscalar(x1) or dpnp.isscalar(x2):
429+
_x1 = x1
430+
_x2 = x2
431+
else:
432+
_x1 = x1[(Ellipsis,) + (None,) * x2.ndim]
433+
_x2 = x2[(None,) * x1.ndim + (Ellipsis,)]
434+
return self.__call__(
435+
_x1,
436+
_x2,
437+
out=out,
438+
where=where,
439+
order=order,
440+
dtype=dtype,
441+
subok=subok,
442+
**kwargs,
443+
)
444+
344445

345446
class DPNPAngle(DPNPUnaryFunc):
346447
"""Class that implements dpnp.angle unary element-wise functions."""

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@
4343

4444
import dpnp
4545

46-
# pylint: disable=no-name-in-module
47-
from .dpnp_utils import (
48-
call_origin,
49-
)
5046
from .dpnp_utils.dpnp_utils_linearalgebra import (
5147
dpnp_dot,
5248
dpnp_einsum,
@@ -851,62 +847,58 @@ def matmul(
851847
)
852848

853849

854-
def outer(x1, x2, out=None):
850+
def outer(a, b, out=None):
855851
"""
856852
Returns the outer product of two arrays.
857853
858854
For full documentation refer to :obj:`numpy.outer`.
859855
860-
Limitations
861-
-----------
862-
Parameters `x1` and `x2` are supported as either scalar,
863-
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`, but both
864-
`x1` and `x2` can not be scalars at the same time. Otherwise
865-
the functions will be executed sequentially on CPU.
866-
Input array data types are limited by supported DPNP :ref:`Data types`.
856+
Parameters
857+
----------
858+
a : {dpnp.ndarray, usm_ndarray}
859+
First input vector. Input is flattened if not already 1-dimensional.
860+
b : {dpnp.ndarray, usm_ndarray}
861+
Second input vector. Input is flattened if not already 1-dimensional.
862+
out : {None, dpnp.ndarray, usm_ndarray}, optional
863+
A location where the result is stored
864+
865+
Returns
866+
-------
867+
out : dpnp.ndarray
868+
out[i, j] = a[i] * b[j]
867869
868870
See Also
869871
--------
870872
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention
871873
on the operands.
872874
:obj:`dpnp.inner` : Returns the inner product of two arrays.
875+
:obj:`dpnp.tensordot` : dpnp.tensordot(a.ravel(), b.ravel(), axes=((), ()))
876+
is the equivalent.
873877
874878
Examples
875879
--------
876880
>>> import dpnp as np
877881
>>> a = np.array([1, 1, 1])
878882
>>> b = np.array([1, 2, 3])
879-
>>> result = np.outer(a, b)
880-
>>> [x for x in result]
883+
>>> np.outer(a, b)
881884
array([[1, 2, 3],
882885
[1, 2, 3],
883886
[1, 2, 3]])
884887
885888
"""
886889

887-
x1_is_scalar = dpnp.isscalar(x1)
888-
x2_is_scalar = dpnp.isscalar(x2)
889-
890-
if x1_is_scalar and x2_is_scalar:
891-
pass
892-
elif not dpnp.is_supported_array_or_scalar(x1):
893-
pass
894-
elif not dpnp.is_supported_array_or_scalar(x2):
895-
pass
890+
dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
891+
if dpnp.isscalar(a):
892+
x1 = a
893+
x2 = b.ravel()[None, :]
894+
elif dpnp.isscalar(b):
895+
x1 = a.ravel()[:, None]
896+
x2 = b
896897
else:
897-
x1_in = (
898-
x1
899-
if x1_is_scalar
900-
else (x1.reshape(-1) if x1.ndim > 1 else x1)[:, None]
901-
)
902-
x2_in = (
903-
x2
904-
if x2_is_scalar
905-
else (x2.reshape(-1) if x2.ndim > 1 else x2)[None, :]
906-
)
907-
return dpnp.multiply(x1_in, x2_in, out=out)
898+
x1 = a.ravel()
899+
x2 = b.ravel()
908900

909-
return call_origin(numpy.outer, x1, x2, out=out)
901+
return dpnp.multiply.outer(x1, x2, out=out)
910902

911903

912904
def tensordot(a, b, axes=2):

tests/test_flipping.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_arange_4d(self, axis, dtype):
6060
)
6161
def test_lr_equivalent(self, dtype):
6262
dp_a = dpnp.arange(4, dtype=dtype)
63-
dp_a = dp_a[:, dpnp.newaxis] + dp_a[dpnp.newaxis, :]
63+
dp_a = dpnp.add.outer(dp_a, dp_a)
6464
assert_equal(dpnp.flip(dp_a, 1), dpnp.fliplr(dp_a))
6565

6666
np_a = numpy.arange(4, dtype=dtype)
@@ -72,7 +72,7 @@ def test_lr_equivalent(self, dtype):
7272
)
7373
def test_ud_equivalent(self, dtype):
7474
dp_a = dpnp.arange(4, dtype=dtype)
75-
dp_a = dp_a[:, dpnp.newaxis] + dp_a[dpnp.newaxis, :]
75+
dp_a = dpnp.add.outer(dp_a, dp_a)
7676
assert_equal(dpnp.flip(dp_a, 0), dpnp.flipud(dp_a))
7777

7878
np_a = numpy.arange(4, dtype=dtype)

tests/test_mathematical.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2875,3 +2875,45 @@ def test_bitwise_1array_input():
28752875
result = dpnp.add(1, x, dtype="f4")
28762876
expected = numpy.add(1, x_np, dtype="f4")
28772877
assert_dtype_allclose(result, expected)
2878+
2879+
2880+
@pytest.mark.parametrize(
2881+
"x_shape",
2882+
[
2883+
(),
2884+
(2),
2885+
(3, 4),
2886+
(3, 4, 5),
2887+
],
2888+
)
2889+
@pytest.mark.parametrize(
2890+
"y_shape",
2891+
[
2892+
(),
2893+
(2),
2894+
(3, 4),
2895+
(3, 4, 5),
2896+
],
2897+
)
2898+
def test_elemenwise_outer(x_shape, y_shape):
2899+
x_np = numpy.random.random(x_shape)
2900+
y_np = numpy.random.random(y_shape)
2901+
expected = numpy.multiply.outer(x_np, y_np)
2902+
2903+
x = dpnp.asarray(x_np)
2904+
y = dpnp.asarray(y_np)
2905+
result = dpnp.multiply.outer(x, y)
2906+
2907+
assert_dtype_allclose(result, expected)
2908+
2909+
result_outer = dpnp.outer(x, y)
2910+
assert dpnp.allclose(result.flatten(), result_outer.flatten())
2911+
2912+
2913+
def test_elemenwise_outer_scalar():
2914+
s = 5
2915+
x = dpnp.asarray([1, 2, 3])
2916+
y = dpnp.asarray(s)
2917+
expected = dpnp.add.outer(x, y)
2918+
result = dpnp.add.outer(x, s)
2919+
assert_dtype_allclose(result, expected)

tests/test_outer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,32 +42,25 @@ class TestScalarOuter(unittest.TestCase):
4242
@testing.for_all_dtypes()
4343
@testing.numpy_cupy_allclose(type_check=False)
4444
def test_first_is_scalar(self, xp, dtype):
45-
scalar = xp.int64(4)
45+
scalar = 4
4646
a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5)
4747
return xp.outer(scalar, a)
4848

4949
@testing.for_all_dtypes()
5050
@testing.numpy_cupy_allclose(type_check=False)
5151
def test_second_is_scalar(self, xp, dtype):
52-
scalar = xp.int32(7)
52+
scalar = 7
5353
a = xp.arange(5**3, dtype=dtype).reshape(5, 5, 5)
5454
return xp.outer(a, scalar)
5555

56-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
57-
@testing.numpy_cupy_array_equal()
58-
def test_both_inputs_as_scalar(self, xp):
59-
a = xp.int64(4)
60-
b = xp.int32(17)
61-
return xp.outer(a, b)
62-
6356

6457
class TestListOuter(unittest.TestCase):
6558
def test_list(self):
6659
a = np.arange(27).reshape(3, 3, 3)
6760
b: list[list[list[int]]] = a.tolist()
6861
dp_a = dp.array(a)
6962

70-
with assert_raises(NotImplementedError):
63+
with assert_raises(TypeError):
7164
dp.outer(b, dp_a)
7265
dp.outer(dp_a, b)
7366
dp.outer(b, b)

0 commit comments

Comments
 (0)