Skip to content

Commit c746621

Browse files
committed
cupy.broadcast_arrays: make it return a tuple
1 parent 9bce442 commit c746621

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
139139
return cp.take_along_axis(x, indices, axis=axis)
140140

141141

142+
# https://github.com/cupy/cupy/pull/9582
143+
def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]:
144+
return tuple(cp.broadcast_arrays(*arrays))
145+
146+
142147
# These functions are completely new here. If the library already has them
143148
# (i.e., numpy 2.0), use the library version instead of our wrapper.
144149
if hasattr(cp, 'vecdot'):
@@ -161,7 +166,8 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
161166
'atan2', 'atanh', 'bitwise_left_shift',
162167
'bitwise_invert', 'bitwise_right_shift',
163168
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
164-
'ceil', 'floor', 'trunc', 'take_along_axis']
169+
'ceil', 'floor', 'trunc', 'take_along_axis',
170+
'broadcast_arrays',]
165171

166172

167173
def __dir__() -> list[str]:

0 commit comments

Comments
 (0)