Skip to content

Commit 1096977

Browse files
committed
ENH: add broadcast_shapes for array_api_version >= 2025.12
1 parent c303adc commit 1096977

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from ._data_type_functions import (
7474
astype,
7575
broadcast_arrays,
76+
broadcast_shapes,
7677
broadcast_to,
7778
can_cast,
7879
finfo,
@@ -84,6 +85,7 @@
8485
__all__ += [
8586
"astype",
8687
"broadcast_arrays",
88+
"broadcast_shapes",
8789
"broadcast_to",
8890
"can_cast",
8991
"finfo",

array_api_strict/_data_type_functions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
_signed_integer_dtypes,
1717
_unsigned_integer_dtypes,
1818
)
19-
from ._flags import get_array_api_strict_flags
19+
from ._flags import get_array_api_strict_flags, requires_api_version
2020

2121

2222
# Note: astype is a function, not an array method as in NumPy.
@@ -62,6 +62,16 @@ def broadcast_arrays(*arrays: Array) -> list[Array]:
6262
]
6363

6464

65+
@requires_api_version("2025.12")
66+
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
67+
"""
68+
Array API compatible wrapper for :py:func:`np.broadcast_shapes <numpy.broadcast_shapes>`.
69+
70+
See its docstring for more information.
71+
"""
72+
return np.broadcast_shapes(*shapes)
73+
74+
6575
def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array:
6676
"""
6777
Array API compatible wrapper for :py:func:`np.broadcast_to <numpy.broadcast_to>`.

0 commit comments

Comments
 (0)