Skip to content

Commit 827ad32

Browse files
committed
ENH: broadcast_shapes
1 parent 27b0bf2 commit 827ad32

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
at
1010
atleast_nd
11+
broadcast_shapes
1112
cov
1213
create_diagonal
1314
expand_dims

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
5656
"array-api": ("https://data-apis.org/array-api/draft", None),
57+
"numpy": ("https://numpy.org/doc/stable", None),
5758
"jax": ("https://jax.readthedocs.io/en/latest", None),
5859
}
5960

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
7+
broadcast_shapes,
78
cov,
89
create_diagonal,
910
expand_dims,
@@ -20,6 +21,7 @@
2021
"__version__",
2122
"at",
2223
"atleast_nd",
24+
"broadcast_shapes",
2325
"cov",
2426
"create_diagonal",
2527
"expand_dims",

src/array_api_extra/_lib/_funcs.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
__all__ = [
1919
"atleast_nd",
20+
"broadcast_shapes",
2021
"cov",
2122
"create_diagonal",
2223
"expand_dims",
@@ -71,6 +72,66 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
7172
return x
7273

7374

75+
def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
76+
"""
77+
Compute the shape of the broadcasted arrays.
78+
79+
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
80+
None and NaN sizes.
81+
82+
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
83+
without needing to worry about the backend potentially deep copying
84+
the arrays.
85+
86+
Parameters
87+
----------
88+
*shapes : tuple[int | None, ...]
89+
Shapes of the arrays to broadcast.
90+
91+
Returns
92+
-------
93+
tuple[int | None, ...]
94+
The shape of the broadcasted arrays.
95+
96+
See Also
97+
--------
98+
numpy.broadcast_shapes : Equivalent NumPy function.
99+
100+
Notes
101+
-----
102+
This function accepts the Array API's ``None`` for unknown sizes,
103+
as well as Dask's non-standard ``math.nan``.
104+
Regardless of input, the output always contains ``None`` for unknown sizes.
105+
106+
Examples
107+
--------
108+
>>> import array_api_extra as xpx
109+
>>> xpx.broadcast_shapes((2, 3), (2, 1))
110+
(2, 3)
111+
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
112+
(4, 2, 3)
113+
"""
114+
if not shapes:
115+
return () # Match numpy output
116+
117+
ndim = max(len(shape) for shape in shapes)
118+
out: list[int | None] = []
119+
for axis in range(-ndim, 0):
120+
sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
121+
# Dask uses NaN for unknown shape, which predates the Array API spec for None
122+
none_size = None in sizes or math.nan in sizes
123+
sizes -= {1, None, math.nan}
124+
if len(sizes) > 1:
125+
msg = (
126+
"shape mismatch: objects cannot be broadcast to a single shape: "
127+
f"{shapes}."
128+
)
129+
raise ValueError(msg)
130+
out.append(None if none_size else cast(int, sizes.pop()) if sizes else 1)
131+
132+
return tuple(out)
133+
134+
74135
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
75136
"""
76137
Estimate a covariance matrix.

tests/test_funcs.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import math
23
import warnings
34
from types import ModuleType
45

@@ -8,6 +9,7 @@
89
from array_api_extra import (
910
at,
1011
atleast_nd,
12+
broadcast_shapes,
1113
cov,
1214
create_diagonal,
1315
expand_dims,
@@ -113,6 +115,62 @@ def test_xp(self, xp: ModuleType):
113115
xp_assert_equal(y, xp.ones((1,)))
114116

115117

118+
class TestBroadcastShapes:
119+
@pytest.mark.parametrize(
120+
"args",
121+
[
122+
(),
123+
((),),
124+
((), ()),
125+
((1,),),
126+
((1,), (1,)),
127+
((2,), (1,)),
128+
((3, 1, 4), (2, 1)),
129+
((1, 1, 4), (2, 1)),
130+
((1,), ()),
131+
((), (2,), ()),
132+
((0,),),
133+
((0,), (1,)),
134+
((2, 0), (1, 1)),
135+
((2, 0, 3), (2, 1, 1)),
136+
],
137+
)
138+
def test_simple(self, args: tuple[tuple[int, ...], ...]):
139+
expect = np.broadcast_shapes(*args)
140+
actual = broadcast_shapes(*args)
141+
assert actual == expect
142+
143+
@pytest.mark.parametrize(
144+
"args",
145+
[
146+
((2,), (3,)),
147+
((2, 3), (1, 2)),
148+
((2,), (0,)),
149+
((2, 0, 2), (1, 3, 1)),
150+
],
151+
)
152+
def test_fail(self, args: tuple[tuple[int, ...], ...]):
153+
match = "cannot be broadcast to a single shape"
154+
with pytest.raises(ValueError, match=match):
155+
_ = np.broadcast_shapes(*args)
156+
with pytest.raises(ValueError, match=match):
157+
_ = broadcast_shapes(*args)
158+
159+
@pytest.mark.parametrize(
160+
"args",
161+
[
162+
((None,), (None,)),
163+
((math.nan,), (None,)),
164+
((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)),
165+
((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)),
166+
],
167+
)
168+
def test_none(self, args: tuple[tuple[float | None, ...], ...]):
169+
expect = args[-1]
170+
actual = broadcast_shapes(*args[:-1])
171+
assert actual == expect
172+
173+
116174
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
117175
class TestCov:
118176
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)