@@ -295,3 +295,39 @@ def test_scalar_integer_indexing(dims_order):
295295 expected_res2 = x_test [tuple (idxs )]
296296 xr_assert_allclose (res1 , expected_res1 )
297297 xr_assert_allclose (res2 , expected_res2 )
298+
299+
300+ def test_unsupported_boolean_indexing ():
301+ x = xtensor (dims = ("a" , "b" ), shape = (3 , 5 ))
302+
303+ mat_idx = xtensor ("idx" , dtype = bool , shape = (4 , 2 ), dims = ("a" , "b" ))
304+ scalar_idx = mat_idx .isel (a = 0 , b = 1 )
305+
306+ for idx in (mat_idx , mat_idx .values , scalar_idx , scalar_idx .values ):
307+ with pytest .raises (
308+ NotImplementedError ,
309+ match = "Only 1d boolean indexing arrays are supported" ,
310+ ):
311+ x [idx ]
312+
313+
314+ def test_boolean_indexing ():
315+ x = xtensor ("x" , shape = (8 , 7 ), dims = ("a" , "b" ))
316+ bool_idx = xtensor ("bool_idx" , dtype = bool , shape = (8 ,), dims = ("a" ,))
317+ int_idx = xtensor ("int_idx" , dtype = int , shape = (4 , 3 ), dims = ("a" , "new_dim" ))
318+
319+ out_vectorized = x [bool_idx , int_idx ]
320+ out_orthogonal = x [bool_idx , int_idx .rename (a = "b" )]
321+ fn = xr_function ([x , bool_idx , int_idx ], [out_vectorized , out_orthogonal ])
322+
323+ x_test = xr_arange_like (x )
324+ bool_idx_test = DataArray (np .array ([True , False ] * 4 , dtype = bool ), dims = ("a" ,))
325+ int_idx_test = DataArray (
326+ np .random .binomial (n = 4 , p = 0.5 , size = (4 , 3 )),
327+ dims = ("a" , "new_dim" ),
328+ )
329+ res1 , res2 = fn (x_test , bool_idx_test , int_idx_test )
330+ expected_res1 = x_test [bool_idx_test , int_idx_test ]
331+ expected_res2 = x_test [bool_idx_test , int_idx_test .rename (a = "b" )]
332+ xr_assert_allclose (res1 , expected_res1 )
333+ xr_assert_allclose (res2 , expected_res2 )
0 commit comments