Skip to content

Commit d738cba

Browse files
committed
ENH: add broadcast_shapes if array_api_version is at least 2025.12
1 parent c303adc commit d738cba

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

array_api_strict/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
"result_type",
9393
]
9494

95+
9596
from ._dtypes import (
9697
int8,
9798
int16,
@@ -335,6 +336,11 @@
335336
'__version__',
336337
]
337338

339+
if get_array_api_strict_flags()['api_version'] >= '2025.12':
340+
from ._data_type_functions import broadcast_shapes
341+
__all__ += ["broadcast_shapes"]
342+
343+
338344
try:
339345
from ._version import __version__ # type: ignore[import-not-found,unused-ignore]
340346
except ImportError:

array_api_strict/_data_type_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ def broadcast_arrays(*arrays: Array) -> list[Array]:
6262
]
6363

6464

65+
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
66+
"""
67+
Array API compatible wrapper for :py:func:`np.broadcast_shapes <numpy.broadcast_shapes>`.
68+
69+
See its docstring for more information.
70+
"""
71+
return np.broadcast_shapes(*shapes)
72+
73+
6574
def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array:
6675
"""
6776
Array API compatible wrapper for :py:func:`np.broadcast_to <numpy.broadcast_to>`.

0 commit comments

Comments
 (0)