Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,13 @@ def __getitem__(
# docstring of _validate_index
self._validate_index(key)
if isinstance(key, Array):
key = (key,)
if isinstance(key, tuple):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
# e.g., when using non-default device
key = tuple(
subkey._array if isinstance(subkey, Array) else subkey for subkey in key
)
res = self._array.__getitem__(key)
return self._new(res, device=self.device)

Expand Down
22 changes: 13 additions & 9 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from .. import ones, arange, reshape, asarray, result_type, all, equal
from .. import ones, arange, reshape, asarray, result_type, all, equal, stack
from .._array_object import Array, CPU_DEVICE, Device
from .._dtypes import (
_all_dtypes,
Expand Down Expand Up @@ -101,33 +101,37 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[idx])


def test_indexing_arrays():
# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"])
def test_indexing_arrays(device='device1'):
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
device = Device(device)

# 1D array
a = arange(5)
idx = asarray([1, 0, 1, 2, -1])
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx]

a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape

# setitem with arrays is not allowed
with assert_raises(IndexError):
a[idx] = 42

# mixed array and integer indexing
a = reshape(arange(3*4), (3, 4))
idx = asarray([1, 0, 1, 2, -1])
a = reshape(arange(3*4, device=device), (3, 4))
idx = asarray([1, 0, 1, 2, -1], device=device)
a_idx = a[idx, 1]

a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == idx.shape

# index with two arrays
a_idx = a[idx, idx]
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
assert all(a_idx == a_idx_loop)
assert a_idx.shape == a_idx.shape

# setitem with arrays is not allowed
with assert_raises(IndexError):
Expand Down
Loading