Skip to content

Commit 001ace3

Browse files
committed
fix
1 parent a8f567a commit 001ace3

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

array_api_strict/_array_object.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,8 +698,13 @@ def __getitem__(
698698
# docstring of _validate_index
699699
self._validate_index(key)
700700
if isinstance(key, Array):
701+
key = (key,)
702+
if isinstance(key, tuple):
701703
# Indexing self._array with array_api_strict arrays can be erroneous
702-
key = key._array
704+
# e.g., when using non-default device
705+
key = tuple(
706+
subkey._array if isinstance(subkey, Array) else subkey for subkey in key
707+
)
703708
res = self._array.__getitem__(key)
704709
return self._new(res, device=self.device)
705710

array_api_strict/tests/test_array_object.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from .. import ones, arange, reshape, asarray, result_type, all, equal
8+
from .. import ones, arange, reshape, asarray, result_type, all, equal, stack
99
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
@@ -101,33 +101,37 @@ def test_validate_index():
101101
assert_raises(IndexError, lambda: a[idx])
102102

103103

104-
def test_indexing_arrays():
104+
# @pytest.mark.parametrize("device", ["CPU_DEVICE", "device1", "device2"])
105+
def test_indexing_arrays(device='device1'):
105106
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
107+
device = Device(device)
106108

107109
# 1D array
108110
a = arange(5)
109-
idx = asarray([1, 0, 1, 2, -1])
111+
idx = asarray([1, 0, 1, 2, -1], device=device)
110112
a_idx = a[idx]
111113

112-
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
114+
a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])])
113115
assert all(a_idx == a_idx_loop)
116+
assert a_idx.shape == idx.shape
114117

115118
# setitem with arrays is not allowed
116119
with assert_raises(IndexError):
117120
a[idx] = 42
118121

119122
# mixed array and integer indexing
120-
a = reshape(arange(3*4), (3, 4))
121-
idx = asarray([1, 0, 1, 2, -1])
123+
a = reshape(arange(3*4, device=device), (3, 4))
124+
idx = asarray([1, 0, 1, 2, -1], device=device)
122125
a_idx = a[idx, 1]
123-
124-
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
126+
a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])])
125127
assert all(a_idx == a_idx_loop)
128+
assert a_idx.shape == idx.shape
126129

127130
# index with two arrays
128131
a_idx = a[idx, idx]
129-
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
132+
a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])])
130133
assert all(a_idx == a_idx_loop)
134+
assert a_idx.shape == a_idx.shape
131135

132136
# setitem with arrays is not allowed
133137
with assert_raises(IndexError):

0 commit comments

Comments
 (0)