diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index 64c51ce..b214932 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -304,7 +304,7 @@ def linspace( ) -def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]: +def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.meshgrid `. @@ -327,10 +327,12 @@ def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array else: device = None - return [ + typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple + + return typ( Array._new(array, device=device) for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing) - ] + ) def ones( diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index 82d438f..1fc3ac2 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -49,7 +49,7 @@ def astype( return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device) -def broadcast_arrays(*arrays: Array) -> list[Array]: +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.broadcast_arrays `. @@ -57,9 +57,11 @@ def broadcast_arrays(*arrays: Array) -> list[Array]: """ from ._array_object import Array - return [ + typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple + + return typ( Array._new(array, device=arrays[0].device) for array in np.broadcast_arrays(*[a._array for a in arrays]) - ] + ) def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: diff --git a/array_api_strict/_info.py b/array_api_strict/_info.py index 0eb6696..b2be132 100644 --- a/array_api_strict/_info.py +++ b/array_api_strict/_info.py @@ -130,5 +130,5 @@ def dtypes( raise ValueError(f"unsupported kind: {kind!r}") @requires_api_version('2023.12') - def devices(self) -> list[Device]: - return list(ALL_DEVICES) + def devices(self) -> tuple[Device]: + return tuple(ALL_DEVICES)