Skip to content

Commit c8ae92b

Browse files
committed
iter
1 parent dc71844 commit c8ae92b

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

array_api_strict/_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def __getitem__(
722722
devices = {self.device}
723723
if isinstance(key, tuple):
724724
devices.update(
725-
[subkey.device for subkey in key if hasattr(subkey, "device")]
725+
[subkey.device for subkey in key if isinstance(subkey, Array)]
726726
)
727727
if len(devices) > 1:
728728
raise ValueError(

array_api_strict/tests/test_array_object.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,31 @@ def test_validate_index():
101101
assert_raises(IndexError, lambda: a[idx])
102102

103103

104+
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
105+
@pytest.mark.parametrize(
106+
"integer_index",
107+
[
108+
1,
109+
np.bool(1),
110+
np.int8(0),
111+
np.uint8(0),
112+
np.int16(0),
113+
np.uint16(0),
114+
np.int32(0),
115+
np.uint32(0),
116+
np.int64(0),
117+
np.uint64(0),
118+
2,
119+
],
120+
)
121+
def test_indexing_ints(integer_index, device):
122+
# Ensure indexing with different integer types works on all Devices.
123+
device = None if device is None else Device(device)
124+
125+
a = arange(5, device=device)
126+
a[integer_index]
127+
128+
104129
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
105130
def test_indexing_arrays(device):
106131
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed

0 commit comments

Comments
 (0)