@@ -150,6 +150,19 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
150
150
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
151
151
Array([125], dtype=int32)
152
152
153
+ For frameworks that don't support fancy indexing by default, e.g. array-api-strict,
154
+ we implement a workaround for 1D integer indices and ``xpx.at().set``. Assignments
155
+ with multiple occurences of the same index always choose the last occurence. This is
156
+ consistent with numpy's behaviour, e.g.::
157
+
158
+ >>> import numpy as np
159
+ >>> import array_api_strict as xp
160
+ >>> import array_api_extra as xpx
161
+ >>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
162
+ array([3])
163
+ >>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
164
+ Array([3], dtype=array_api_strict.int64)
165
+
153
166
See Also
154
167
--------
155
168
jax.numpy.ndarray.at : Equivalent array method in JAX.
@@ -354,39 +367,50 @@ def _op(
354
367
if is_torch_array (y ):
355
368
y = xp .astype (y , x .dtype , copy = False )
356
369
357
- # Work around lack of fancy indexing __setitem__ support in array-api-strict.
358
- if (
359
- is_array_api_strict_namespace (xp )
360
- and is_array_api_obj (idx )
361
- and xp .isdtype (idx .dtype , "integral" )
362
- and out_of_place_op is None # only use for set()
363
- ):
364
- # Vectorize the operation using boolean indexing
365
- # For non-unique indices, take the last occurrence. This requires creating
366
- # masks for x and y that create matching shapes.
367
- unique_indices , _ = xp .unique_inverse (idx )
368
- x_mask = xp .any (xp .arange (x .shape [0 ])[..., None ] == unique_indices , axis = - 1 )
369
- # Get last occurrence of each unique index
370
- cmp = unique_indices [:, None ] == unique_indices [None , :]
371
- # Ignore later matches
372
- lower_tri_mask = (
373
- xp .arange (y .shape [0 ])[:, None ] >= xp .arange (y .shape [0 ])[None , :]
374
- )
375
- masked_cmp = cmp & lower_tri_mask
376
- # For each position i, count how many matches occurred before i
377
- prior_matches = xp .sum (xp .astype (masked_cmp , xp .int32 ), axis = - 1 )
378
- # Last occurrence has highest match count
379
- y_mask = prior_matches == xp .max (prior_matches , axis = - 1 )
380
- # Apply the operation only to last occurrences
381
- x [x_mask ] = y [y_mask ]
382
- return x
383
-
384
370
# Backends without boolean indexing (other than JAX) crash here
385
371
if in_place_op : # add(), subtract(), ...
386
372
x [idx ] = in_place_op (x [idx ], y )
387
- else : # set()
373
+ return x
374
+ # set()
375
+ try : # We first try to use the backend's __setitem__ if available
388
376
x [idx ] = y
389
- return x
377
+ return x
378
+ except IndexError as e :
379
+ if "Fancy indexing" not in str (e ): # Avoid masking other index errors
380
+ raise e
381
+ # Work around lack of fancy indexing __setitem__
382
+ if (
383
+ is_array_api_obj (idx )
384
+ and xp .isdtype (idx .dtype , "integral" )
385
+ and idx .ndim == 1
386
+ ):
387
+ # Vectorize the operation using boolean indexing
388
+ # For non-unique indices, take the last occurrence. This requires
389
+ # masks for x and y that create matching shapes.
390
+ # We first create the mask for x
391
+ u_idx , _ = xp .unique_inverse (idx )
392
+ # Convert negative indices to positive, otherwise they won't get matched
393
+ u_idx = xp .where (u_idx < 0 , x .shape [0 ] + u_idx , u_idx )
394
+ x_mask = xp .any (xp .arange (x .shape [0 ])[..., None ] == u_idx , axis = - 1 )
395
+ # If y is a scalar or 0D, we are done
396
+ if not is_array_api_obj (y ) or y .ndim == 0 :
397
+ x [x_mask ] = y
398
+ return x
399
+ # If not, create a mask for y. Get last occurrence of each unique index
400
+ cmp = u_idx [:, None ] == u_idx [None , :]
401
+ # Ignore later matches
402
+ lower_tri_mask = (
403
+ xp .arange (y .shape [0 ])[:, None ] >= xp .arange (y .shape [0 ])[None , :]
404
+ )
405
+ masked_cmp = cmp & lower_tri_mask
406
+ # For each position i, count how many matches occurred before i
407
+ prior_matches = xp .sum (xp .astype (masked_cmp , xp .int32 ), axis = - 1 )
408
+ # Last occurrence has highest match count
409
+ y_mask = prior_matches == xp .max (prior_matches , axis = - 1 )
410
+ # Apply the operation only to last occurrences
411
+ x [x_mask ] = y [y_mask ]
412
+ return x
413
+ raise e
390
414
391
415
def set (
392
416
self ,
0 commit comments