|
12 | 12 | from ._utils._compat import (
|
13 | 13 | array_namespace,
|
14 | 14 | is_array_api_obj,
|
15 |
| - is_array_api_strict_namespace, |
16 | 15 | is_dask_array,
|
17 | 16 | is_jax_array,
|
| 17 | + is_lazy_array, |
18 | 18 | is_torch_array,
|
19 | 19 | is_writeable_array,
|
20 | 20 | )
|
@@ -390,14 +390,37 @@ def _op(
|
390 | 390 | # We first create the mask for x
|
391 | 391 | u_idx, _ = xp.unique_inverse(idx)
|
392 | 392 | # Convert negative indices to positive, otherwise they won't get matched
|
393 |
| - u_idx = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) |
394 |
| - x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx, axis=-1) |
| 393 | + u_idx_pos = xp.where(u_idx < 0, x.shape[0] + u_idx, u_idx) |
| 394 | + # Check for out of bounds indices |
| 395 | + oob_index = None |
| 396 | + if is_lazy_array(u_idx_pos): |
| 397 | + pass # Lazy arrays cannot check for out of bounds indices |
| 398 | + elif xp.any(pos_oob := u_idx_pos >= x.shape[0]): |
| 399 | + first_oob_idx = xp.argmax(xp.astype(pos_oob, xp.int32)) |
| 400 | + oob_index = int(u_idx[first_oob_idx]) |
| 401 | + elif xp.any(neg_oob := u_idx_pos < 0): |
| 402 | + first_oob_idx = xp.argmax(xp.astype(neg_oob, xp.int32)) |
| 403 | + oob_index = int(u_idx[first_oob_idx]) |
| 404 | + if oob_index is not None: |
| 405 | + msg = ( |
| 406 | + f"index {oob_index} is out of bounds for array of " |
| 407 | + f"shape {x.shape}" |
| 408 | + ) |
| 409 | + raise IndexError(msg) from e |
| 410 | + |
| 411 | + x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1) |
395 | 412 | # If y is a scalar or 0D, we are done
|
396 | 413 | if not is_array_api_obj(y) or y.ndim == 0:
|
397 | 414 | x[x_mask] = y
|
398 | 415 | return x
|
| 416 | + if y.shape[0] != idx.shape[0]: |
| 417 | + msg = ( |
| 418 | + f"shape mismatch: value array of shape {y.shape} could not be " |
| 419 | + f"broadcast to indexing result of shape {idx.shape}" |
| 420 | + ) |
| 421 | + raise ValueError(msg) from e |
399 | 422 | # If not, create a mask for y. Get last occurrence of each unique index
|
400 |
| - cmp = u_idx[:, None] == u_idx[None, :] |
| 423 | + cmp = u_idx_pos[:, None] == u_idx_pos[None, :] |
401 | 424 | # Ignore later matches
|
402 | 425 | lower_tri_mask = (
|
403 | 426 | xp.arange(y.shape[0])[:, None] >= xp.arange(y.shape[0])[None, :]
|
|
0 commit comments