@@ -242,8 +242,23 @@ def test_setitem_masking(shape, data):
242
242
)
243
243
244
244
245
- @given (shape = hh .shapes (), data = st .data ())
246
- def test_getitem_arrays_and_ints (shape , data ):
245
+ @pytest .mark .unvectorized
246
+ @given (shape = hh .shapes (min_dims = 2 ), data = st .data ())
247
+ def test_getitem_arrays_and_ints_1 (shape , data ):
248
+ # min_dims=2 : test multidim `x` arrays
249
+ # index arrays are all 1D
250
+ _test_getitem_arrays_and_ints_1D (shape , data )
251
+
252
+
253
+ @pytest .mark .unvectorized
254
+ @given (shape = hh .shapes (min_dims = 1 ), data = st .data ())
255
+ def test_getitem_arrays_and_ints_2 (shape , data ):
256
+ # min_dims=1 : favor 1D `x` arrays
257
+ # index arrays are all 1D
258
+ _test_getitem_arrays_and_ints_1D (shape , data )
259
+
260
+
261
+ def _test_getitem_arrays_and_ints_1D (shape , data ):
247
262
assume ((len (shape ) > 0 ) and all (sh > 0 for sh in shape ))
248
263
249
264
dtype = xp .int32
@@ -254,10 +269,12 @@ def test_getitem_arrays_and_ints(shape, data):
254
269
arr_index = [data .draw (st .booleans ()) for _ in range (len (shape ))]
255
270
assume (sum (arr_index ) > 0 )
256
271
257
- # draw shapes for index arrays
272
+ # draw shapes for index arrays: NB max_dims=1 ==> 1D indexing arrays ONLY
258
273
if sum (arr_index ) > 0 :
259
274
index_shapes = data .draw (
260
- hh .mutually_broadcastable_shapes (sum (arr_index ), min_dims = 1 , min_side = 1 )
275
+ hh .mutually_broadcastable_shapes (
276
+ sum (arr_index ), min_dims = 1 , max_dims = 1 , min_side = 1
277
+ )
261
278
)
262
279
index_shapes = list (index_shapes )
263
280
@@ -279,19 +296,18 @@ def test_getitem_arrays_and_ints(shape, data):
279
296
# draw an integer
280
297
key .append (data .draw (st .integers (- shape [i ], shape [i ]- 1 )))
281
298
282
-
283
- print (f"??? { x .shape = } { key = } -- { [k if isinstance (k , int ) else k .shape for k in key ]} " )
299
+ print (f"??? { x .shape = } { key = } " )
284
300
285
301
key = tuple (key )
286
302
out = x [key ]
287
303
288
- # XXX: how to properly check
289
- import numpy as np
290
- x_np = np .asarray (x )
291
- out_np = np .asarray (out )
292
- key_np = tuple (k if isinstance (k , int ) else np .asarray (k ) for k in key )
304
+ arrays = [xp .asarray (k ) for k in key ]
305
+ bcast_shape = sh .broadcast_shapes (* [arr .shape for arr in arrays ])
306
+ bcast_key = [xp .broadcast_to (arr , bcast_shape ) for arr in arrays ]
293
307
294
- np .testing .assert_equal (out_np , x_np [key_np ])
308
+ for idx in sh .ndindex (bcast_shape ):
309
+ tpl = tuple (k [idx ] for k in bcast_key )
310
+ assert out [idx ] == x [tpl ], f"failing at { idx = } w/ { key = } "
295
311
296
312
297
313
def make_scalar_casting_param (
0 commit comments