diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 2e512fc8..badfe390 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -139,6 +139,11 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) +# https://github.com/cupy/cupy/pull/9582 +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: + return tuple(cp.broadcast_arrays(*arrays)) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -161,7 +166,8 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'ceil', 'floor', 'trunc', 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis', + 'broadcast_arrays',] def __dir__() -> list[str]: diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 78e48a33..aef10e85 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -333,4 +333,4 @@ def devices(self): __array_namespace_info__.dtypes """ - return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] + return tuple(cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())) diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 2f39fc4b..3a7285d5 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -379,7 +379,7 @@ def dtypes( return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[Device]: + def devices(self) -> tuple[Device]: """ The devices supported by Dask. @@ -404,4 +404,4 @@ def devices(self) -> list[Device]: ['cpu', DASK_DEVICE] """ - return ["cpu", _DASK_DEVICE] + return ("cpu", _DASK_DEVICE) diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index c625c13e..9ba004da 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -332,7 +332,7 @@ def dtypes( return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[Device]: + def devices(self) -> tuple[Device]: """ The devices supported by NumPy. @@ -357,7 +357,7 @@ def devices(self) -> list[Device]: ['cpu'] """ - return ["cpu"] + return ("cpu",) __all__ = ["__array_namespace_info__"] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index e40183d8..2b1af799 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -706,9 +706,9 @@ def astype( return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> list[Array]: +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) - return [torch.broadcast_to(a, shape) for a in arrays] + return tuple(torch.broadcast_to(a, shape) for a in arrays) # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. @@ -881,10 +881,11 @@ def sign(x: Array, /) -> Array: return out -def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: - # enforce the default of 'xy' - # TODO: is the return type a list or a tuple - return list(torch.meshgrid(*arrays, indexing=indexing)) +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: + # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it + # will be required to pass the indexing argument." + # Thus always pass it explicitly. + return torch.meshgrid(*arrays, indexing=indexing) __all__ = ['asarray', 'result_type', 'can_cast', diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 818e5d37..050c7846 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -366,4 +366,4 @@ def devices(self): break i += 1 - return devices + return tuple(devices)