Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

at
atleast_nd
broadcast_shapes
cov
create_diagonal
expand_dims
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"array-api": ("https://data-apis.org/array-api/draft", None),
"numpy": ("https://numpy.org/doc/stable", None),
"jax": ("https://jax.readthedocs.io/en/latest", None),
}

Expand Down
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._lib._at import at
from ._lib._funcs import (
atleast_nd,
broadcast_shapes,
cov,
create_diagonal,
expand_dims,
Expand All @@ -20,6 +21,7 @@
"__version__",
"at",
"atleast_nd",
"broadcast_shapes",
"cov",
"create_diagonal",
"expand_dims",
Expand Down
62 changes: 62 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

__all__ = [
"atleast_nd",
"broadcast_shapes",
"cov",
"create_diagonal",
"expand_dims",
Expand Down Expand Up @@ -71,6 +72,67 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
return x


def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
"""
Compute the shape of the broadcasted arrays.

Duplicates :func:`numpy.broadcast_shapes`, with additional support for
None and NaN sizes.

This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
without needing to worry about the backend potentially deep copying
the arrays.

Parameters
----------
*shapes : tuple[int | None, ...]
Shapes of the arrays to broadcast.

Returns
-------
tuple[int | None, ...]
The shape of the broadcasted arrays.

See Also
--------
numpy.broadcast_shapes : Equivalent NumPy function.
array_api.broadcast_arrays : Function to broadcast actual arrays.

Notes
-----
This function accepts the Array API's ``None`` for unknown sizes,
as well as Dask's non-standard ``math.nan``.
Regardless of input, the output always contains ``None`` for unknown sizes.

Examples
--------
>>> import array_api_extra as xpx
>>> xpx.broadcast_shapes((2, 3), (2, 1))
(2, 3)
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
(4, 2, 3)
"""
if not shapes:
return () # Match numpy output

ndim = max(len(shape) for shape in shapes)
out: list[int | None] = []
for axis in range(-ndim, 0):
sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
# Dask uses NaN for unknown shape, which predates the Array API spec for None
none_size = None in sizes or math.nan in sizes
sizes -= {1, None, math.nan}
if len(sizes) > 1:
msg = (
"shape mismatch: objects cannot be broadcast to a single shape: "
f"{shapes}."
)
raise ValueError(msg)
out.append(None if none_size else cast(int, sizes.pop()) if sizes else 1)

return tuple(out)


def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix.
Expand Down
58 changes: 58 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import math
import warnings
from types import ModuleType

Expand All @@ -8,6 +9,7 @@
from array_api_extra import (
at,
atleast_nd,
broadcast_shapes,
cov,
create_diagonal,
expand_dims,
Expand Down Expand Up @@ -113,6 +115,62 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(y, xp.ones((1,)))


class TestBroadcastShapes:
@pytest.mark.parametrize(
"args",
[
(),
((),),
((), ()),
((1,),),
((1,), (1,)),
((2,), (1,)),
((3, 1, 4), (2, 1)),
((1, 1, 4), (2, 1)),
((1,), ()),
((), (2,), ()),
((0,),),
((0,), (1,)),
((2, 0), (1, 1)),
((2, 0, 3), (2, 1, 1)),
],
)
def test_simple(self, args: tuple[tuple[int, ...], ...]):
expect = np.broadcast_shapes(*args)
actual = broadcast_shapes(*args)
assert actual == expect

@pytest.mark.parametrize(
"args",
[
((2,), (3,)),
((2, 3), (1, 2)),
((2,), (0,)),
((2, 0, 2), (1, 3, 1)),
],
)
def test_fail(self, args: tuple[tuple[int, ...], ...]):
match = "cannot be broadcast to a single shape"
with pytest.raises(ValueError, match=match):
_ = np.broadcast_shapes(*args)
with pytest.raises(ValueError, match=match):
_ = broadcast_shapes(*args)

@pytest.mark.parametrize(
"args",
[
((None,), (None,)),
((math.nan,), (None,)),
((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)),
((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)),
],
)
def test_none(self, args: tuple[tuple[float | None, ...], ...]):
expect = args[-1]
actual = broadcast_shapes(*args[:-1])
assert actual == expect


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
class TestCov:
def test_basic(self, xp: ModuleType):
Expand Down