Skip to content

Commit c3c00a7

Browse files
authored
Remove comment about take_along_axis not being in array API (#809)
1 parent e4e990c commit c3c00a7

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

cubed/backend_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# namespace variable, and defaults to array_api_compat.nump, unless it
88
# is overridden by an environment variable.
99
# It must be compatible with the Python Array API standard, although
10-
# some extra functions are used too (nan functions, take_along_axis),
10+
# some extra functions are used too (e.g. nan functions),
1111
# which array_api_compat provides, but other Array API implementations
1212
# may not.
1313

cubed/core/ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,6 @@ def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None
14331433

14341434
def _arg_map_func(a, axis, arg_func=None, size=None, block_id=None):
14351435
i = arg_func(a, axis=axis, keepdims=True)
1436-
# note that the array API doesn't have take_along_axis, so this may fail
14371436
v = nxp.take_along_axis(a, i, axis=axis)
14381437
# add block offset to i so it is absolute index within whole array
14391438
offset = block_id[axis] * size
@@ -1454,7 +1453,6 @@ def _arg_combine(a, arg_func=None, **kwargs):
14541453

14551454
# find indexes of values in v and apply to i and v
14561455
vi = arg_func(v, axis=axis, **kwargs)
1457-
# note that the array API doesn't have take_along_axis, so this may fail
14581456
i_combined = nxp.take_along_axis(i, vi, axis=axis)
14591457
v_combined = nxp.take_along_axis(v, vi, axis=axis)
14601458
return {"i": i_combined, "v": v_combined}

0 commit comments

Comments
 (0)