@@ -52,7 +52,7 @@ def test_single_adv_indexing_on_existing_dim():
5252 idx_test = np .array ([0 , 1 , 0 , 2 ], dtype = int )
5353 xidx_test = DataArray (idx_test , dims = ("a" ,))
5454
55- # Three equivalent ways of indexing a->a
55+ # Equivalent ways of indexing a->a
5656 y = x [idx ]
5757 fn = xr_function ([x , idx ], y )
5858 res = fn (x_test , idx_test )
@@ -65,6 +65,12 @@ def test_single_adv_indexing_on_existing_dim():
6565 expected_res = x_test [(("a" , idx_test ),)]
6666 xr_assert_allclose (res , expected_res )
6767
68+ y = x [((("a" ,), idx ),)]
69+ fn = xr_function ([x , idx ], y )
70+ res = fn (x_test , idx_test )
71+ expected_res = x_test [((("a" ,), idx_test ),)]
72+ xr_assert_allclose (res , expected_res )
73+
6874 y = x [xidx ]
6975 fn = xr_function ([x , xidx ], y )
7076 res = fn (x_test , xidx_test )
@@ -81,13 +87,19 @@ def test_single_vector_indexing_on_new_dim():
8187 idx_test = np .array ([0 , 1 , 0 , 2 ], dtype = int )
8288 xidx_test = DataArray (idx_test , dims = ("a" ,))
8389
84- # Two equivalent ways of indexing a->new_a
90+ # Equivalent ways of indexing a->new_a
8591 y = x [(("new_a" , idx ),)]
8692 fn = xr_function ([x , idx ], y )
8793 res = fn (x_test , idx_test )
8894 expected_res = x_test [(("new_a" , idx_test ),)]
8995 xr_assert_allclose (res , expected_res )
9096
97+ y = x [((["new_a" ], idx ),)]
98+ fn = xr_function ([x , idx ], y )
99+ res = fn (x_test , idx_test )
100+ expected_res = x_test [((["new_a" ], idx_test ),)]
101+ xr_assert_allclose (res , expected_res )
102+
91103 y = x [xidx .rename (a = "new_a" )]
92104 fn = xr_function ([x , xidx ], y )
93105 res = fn (x_test , xidx_test )
@@ -176,6 +188,34 @@ def test_matrix_indexing():
176188 xr_assert_allclose (res , expected_res )
177189
178190
191+ def test_assign_multiple_out_dims ():
192+ x = xtensor ("x" , shape = (5 , 7 ), dims = ("a" , "b" ))
193+ idx1 = tensor ("idx1" , dtype = int , shape = (4 , 3 ))
194+ idx2 = tensor ("idx2" , dtype = int , shape = (3 , 2 ))
195+ out = x [(("out1" , "out2" ), idx1 ), (["out2" , "out3" ], idx2 )]
196+
197+ fn = xr_function ([x , idx1 , idx2 ], out )
198+
199+ rng = np .random .default_rng ()
200+ x_test = xr_arange_like (x )
201+ idx1_test = rng .binomial (n = 4 , p = 0.5 , size = (4 , 3 ))
202+ idx2_test = rng .binomial (n = 4 , p = 0.5 , size = (3 , 2 ))
203+ res = fn (x_test , idx1_test , idx2_test )
204+ expected_res = x_test [(("out1" , "out2" ), idx1_test ), (["out2" , "out3" ], idx2_test )]
205+ xr_assert_allclose (res , expected_res )
206+
207+
208+ def test_assign_dims_xtensor_fails ():
209+ x = xtensor ("x" , shape = (5 , 7 ), dims = ("a" , "b" ))
210+ idx1 = xtensor ("idx1" , dtype = int , shape = (4 ,), dims = ("c" ,))
211+
212+ with pytest .raises (
213+ TypeError ,
214+ match = "Giving a dimension name to an XTensorVariable indexer is not supported" ,
215+ ):
216+ x [("d" , idx1 ),]
217+
218+
179219class TestVectorizedIndexingNotAllowedToBroadcast :
180220 def test_compile_time_error (self ):
181221 x = xtensor (dims = ("a" , "b" ), shape = (3 , 5 ))
0 commit comments