Skip to content

Commit d5824c7

Browse files
committed
review
1 parent 58db9c1 commit d5824c7

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

array_api_strict/tests/test_array_object.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,18 @@ def test_validate_index():
100100
assert_raises(IndexError, lambda: a[:])
101101
assert_raises(IndexError, lambda: a[idx])
102102

103+
class DummyIndex:
104+
def __init__(self, x):
105+
self.x = x
106+
def __index__(self):
107+
return self.x
108+
103109

104110
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])
105111
@pytest.mark.parametrize(
106112
"integer_index",
107113
[
108-
1,
114+
0,
109115
np.int8(0),
110116
np.uint8(0),
111117
np.int16(0),
@@ -114,14 +120,15 @@ def test_validate_index():
114120
np.uint32(0),
115121
np.int64(0),
116122
np.uint64(0),
123+
DummyIndex(0),
117124
],
118125
)
119126
def test_indexing_ints(integer_index, device):
120127
# Ensure indexing with different integer types works on all Devices.
121128
device = None if device is None else Device(device)
122129

123130
a = arange(5, device=device)
124-
a[integer_index]
131+
assert a[(integer_index,)] == a[integer_index] == a[0]
125132

126133

127134
@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"])

0 commit comments

Comments
 (0)