Skip to content

Commit ad1c9be

Browse files
committed
ENH: only test with 1D index arrays, make test unvecorized
1 parent d5d3080 commit ad1c9be

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

array_api_tests/test_array_object.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,23 @@ def test_setitem_masking(shape, data):
242242
)
243243

244244

245-
@given(shape=hh.shapes(), data=st.data())
246-
def test_getitem_arrays_and_ints(shape, data):
245+
@pytest.mark.unvectorized
246+
@given(shape=hh.shapes(min_dims=2), data=st.data())
247+
def test_getitem_arrays_and_ints_1(shape, data):
248+
# min_dims=2 : test multidim `x` arrays
249+
# index arrays are all 1D
250+
_test_getitem_arrays_and_ints_1D(shape, data)
251+
252+
253+
@pytest.mark.unvectorized
254+
@given(shape=hh.shapes(min_dims=1), data=st.data())
255+
def test_getitem_arrays_and_ints_2(shape, data):
256+
# min_dims=1 : favor 1D `x` arrays
257+
# index arrays are all 1D
258+
_test_getitem_arrays_and_ints_1D(shape, data)
259+
260+
261+
def _test_getitem_arrays_and_ints_1D(shape, data):
247262
assume((len(shape) > 0) and all(sh > 0 for sh in shape))
248263

249264
dtype = xp.int32
@@ -254,10 +269,12 @@ def test_getitem_arrays_and_ints(shape, data):
254269
arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
255270
assume(sum(arr_index) > 0)
256271

257-
# draw shapes for index arrays
272+
# draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
258273
if sum(arr_index) > 0:
259274
index_shapes = data.draw(
260-
hh.mutually_broadcastable_shapes(sum(arr_index), min_dims=1, min_side=1)
275+
hh.mutually_broadcastable_shapes(
276+
sum(arr_index), min_dims=1, max_dims=1, min_side=1
277+
)
261278
)
262279
index_shapes = list(index_shapes)
263280

@@ -279,19 +296,18 @@ def test_getitem_arrays_and_ints(shape, data):
279296
# draw an integer
280297
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
281298

282-
283-
print(f"??? {x.shape = } {key = } -- {[k if isinstance(k, int) else k.shape for k in key]}")
299+
print(f"??? {x.shape = } {key = }")
284300

285301
key = tuple(key)
286302
out = x[key]
287303

288-
# XXX: how to properly check
289-
import numpy as np
290-
x_np = np.asarray(x)
291-
out_np = np.asarray(out)
292-
key_np = tuple(k if isinstance(k, int) else np.asarray(k) for k in key)
304+
arrays = [xp.asarray(k) for k in key]
305+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
306+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
293307

294-
np.testing.assert_equal(out_np, x_np[key_np])
308+
for idx in sh.ndindex(bcast_shape):
309+
tpl = tuple(k[idx] for k in bcast_key)
310+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
295311

296312

297313
def make_scalar_casting_param(

0 commit comments

Comments
 (0)