From 2bee4e6196d23f11d9bceb649636d1b9746189ec Mon Sep 17 00:00:00 2001 From: "Komarova, Evseniia" Date: Wed, 6 Nov 2024 18:56:21 +0100 Subject: [PATCH] Add dpnp.broadcast_shapes implementation --- dpnp/dpnp_iface_manipulation.py | 36 +++++++++++++++++++++++++++++++++ dpnp/dpnp_iface_mathematical.py | 3 +-- tests/test_manipulation.py | 25 +++++++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 8023040c08a8..6e75caeb7f92 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -63,6 +63,7 @@ "atleast_2d", "atleast_3d", "broadcast_arrays", + "broadcast_shapes", "broadcast_to", "can_cast", "column_stack", @@ -967,6 +968,41 @@ def broadcast_arrays(*args, subok=False): return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays] +def broadcast_shapes(*args): + """ + Broadcast the input shapes into a single shape. + + For full documentation refer to :obj:`numpy.broadcast_shapes`. + + Parameters + ---------- + *args : tuples of ints, or ints + The shapes to be broadcast against each other. + + Returns + ------- + tuple + Broadcasted shape. + + See Also + -------- + :obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against + each other. + :obj:`dpnp.broadcast_to` : Broadcast an array to a new shape. + + Examples + -------- + >>> import dpnp as np + >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2)) + (3, 2) + >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7)) + (5, 6, 7) + + """ + + return numpy.broadcast_shapes(*args) + + # pylint: disable=redefined-outer-name def broadcast_to(array, /, shape, subok=False): """ diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 9318f0af4dc3..f8b477a947ab 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): a_shape = a.shape b_shape = b.shape - # TODO: replace with dpnp.broadcast_shapes once implemented - res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1]) + res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1]) if a_shape[:-1] != res_shape: a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],)) a_shape = a.shape diff --git a/tests/test_manipulation.py b/tests/test_manipulation.py index c137ec6a22a9..2bab4d754f9e 100644 --- a/tests/test_manipulation.py +++ b/tests/test_manipulation.py @@ -332,6 +332,31 @@ def test_no_copy(self): assert_array_equal(b, a) +class TestBroadcast: + @pytest.mark.parametrize( + "shape", + [ + [(1,), (3,)], + [(1, 3), (3, 3)], + [(3, 1), (3, 3)], + [(1, 3), (3, 1)], + [(1, 1), (3, 3)], + [(1, 1), (1, 3)], + [(1, 1), (3, 1)], + [(1, 0), (0, 0)], + [(0, 1), (0, 0)], + [(1, 0), (0, 1)], + [(1, 1), (0, 0)], + [(1, 1), (1, 0)], + [(1, 1), (0, 1)], + ], + ) + def test_broadcast_shapes(self, shape): + expected = numpy.broadcast_shapes(*shape) + result = dpnp.broadcast_shapes(*shape) + assert_equal(result, expected) + + class TestDelete: @pytest.mark.parametrize( "obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]