@@ -167,13 +167,19 @@ def test_dot():
167167 expected = x_test .dot (y_test )
168168 xr_assert_allclose (z_test , expected )
169169
170+ # Test matrix-vector dot product with ellipsis
171+ z = x .dot (y , dims = ...)
172+ fn = xr_function ([x , y ], z )
173+ z_test = fn (x_test , y_test )
174+ expected = x_test .dot (y_test , dim = ...)
175+ xr_assert_allclose (z_test , expected )
176+
170177 # Test matrix-matrix dot product
171178 x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
172179 y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
173180 z = x .dot (y )
174181 fn = xr_function ([x , y ], z )
175182
176- # Use outer product to create test data with diverse values
177183 x_test = DataArray (np .add .outer (np .arange (2.0 ), np .arange (3.0 )), dims = ("a" , "b" ))
178184 y_test = DataArray (np .add .outer (np .arange (3.0 ), np .arange (4.0 )), dims = ("b" , "c" ))
179185 z_test = fn (x_test , y_test )
@@ -195,14 +201,13 @@ def test_dot():
195201 xr_assert_allclose (z_test , expected )
196202
197203 # Test matrix-matrix dot product with ellipsis
198- if True :
199- z = x .dot (y , dims = ...)
200- fn = xr_function ([x , y ], z )
201- z_test = fn (x_test , y_test )
202- expected = x_test .dot (y_test , dim = ...)
203- xr_assert_allclose (z_test , expected )
204-
205- # Test a case where there are two dimensions to contract over
204+ z = x .dot (y , dims = ...)
205+ fn = xr_function ([x , y ], z )
206+ z_test = fn (x_test , y_test )
207+ expected = x_test .dot (y_test , dim = ...)
208+ xr_assert_allclose (z_test , expected )
209+
210+ # Test a case where there are two dimensions to sum over
206211 x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
207212 y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
208213 z = x .dot (y )
@@ -222,13 +227,12 @@ def test_dot():
222227 xr_assert_allclose (z_test , expected )
223228
224229 # Same but with ellipses
225- if True :
226- z = x .dot (y , dims = ...)
227- fn = xr_function ([x , y ], z )
230+ z = x .dot (y , dims = ...)
231+ fn = xr_function ([x , y ], z )
228232
229- z_test = fn (x_test , y_test )
230- expected = x_test .dot (y_test , dim = ...)
231- xr_assert_allclose (z_test , expected )
233+ z_test = fn (x_test , y_test )
234+ expected = x_test .dot (y_test , dim = ...)
235+ xr_assert_allclose (z_test , expected )
232236
233237
234238def test_dot_errors ():
0 commit comments