diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 116df25..7a7e9b6 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -73,6 +73,7 @@ from ._data_type_functions import ( astype, broadcast_arrays, + broadcast_shapes, broadcast_to, can_cast, finfo, @@ -84,6 +85,7 @@ __all__ += [ "astype", "broadcast_arrays", + "broadcast_shapes", "broadcast_to", "can_cast", "finfo", diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 82d438f..e86bc69 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -16,7 +16,7 @@ _signed_integer_dtypes, _unsigned_integer_dtypes, ) -from ._flags import get_array_api_strict_flags +from ._flags import get_array_api_strict_flags, requires_api_version # Note: astype is a function, not an array method as in NumPy. @@ -62,6 +62,16 @@ def broadcast_arrays(*arrays: Array) -> list[Array]: ] +@requires_api_version("2025.12") +def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: + """ + Array API compatible wrapper for :py:func:`np.broadcast_shapes `. + + See its docstring for more information. + """ + return np.broadcast_shapes(*shapes) + + def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: """ Array API compatible wrapper for :py:func:`np.broadcast_to `.