Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def __getitem__(
devices = {self.device}
if isinstance(key, tuple):
devices.update(
[subkey.device for subkey in key if hasattr(subkey, "device")]
[subkey.device for subkey in key if isinstance(subkey, Array)]
)
if len(devices) > 1:
raise ValueError(
Expand Down
24 changes: 24 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,30 @@ def test_validate_index():
assert_raises(IndexError, lambda: a[idx])


@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
@pytest.mark.parametrize(
"integer_index",
[
1,
np.bool(1),
np.int8(0),
np.uint8(0),
np.int16(0),
np.uint16(0),
np.int32(0),
np.uint32(0),
np.int64(0),
np.uint64(0),
],
)
def test_indexing_ints(integer_index, device):
# Ensure indexing with different integer types works on all Devices.
device = None if device is None else Device(device)

a = arange(5, device=device)
a[integer_index]


@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
def test_indexing_arrays(device):
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
Expand Down
Loading