@@ -246,11 +246,10 @@ def test_dot():
246246 z_test = fn (x_test , y_test )
247247 xr_assert_allclose (z_test , expected )
248248
249- return
250249 # Dot product with sum in the middle
251250 # This is not supported yet
252- x_test = DataArray (np .arange (120 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
253- y_test = DataArray (np .arange (360 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
251+ x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
252+ y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
254253 expected = x_test .dot (y_test , dim = ("b" , "d" ))
255254 x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
256255 y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
@@ -259,6 +258,27 @@ def test_dot():
259258 z_test = fn (x_test , y_test )
260259 xr_assert_allclose (z_test , expected )
261260
261+ # Same but with first two dims
262+ expected = x_test .dot (y_test , dim = ["a" , "b" ])
263+ z = x .dot (y , dim = ["a" , "b" ])
264+ fn = xr_function ([x , y ], z )
265+ z_test = fn (x_test , y_test )
266+ xr_assert_allclose (z_test , expected )
267+
268+ # Same but with last two
269+ expected = x_test .dot (y_test , dim = ["d" , "e" ])
270+ z = x .dot (y , dim = ["d" , "e" ])
271+ fn = xr_function ([x , y ], z )
272+ z_test = fn (x_test , y_test )
273+ xr_assert_allclose (z_test , expected )
274+
275+ # Same but with every other dim
276+ expected = x_test .dot (y_test , dim = ["a" , "c" , "e" ])
277+ z = x .dot (y , dim = ["a" , "c" , "e" ])
278+ fn = xr_function ([x , y ], z )
279+ z_test = fn (x_test , y_test )
280+ xr_assert_allclose (z_test , expected )
281+
262282
263283def test_dot_errors ():
264284 x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
@@ -273,7 +293,8 @@ def test_dot_errors():
273293 x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
274294 y = xtensor ("y" , dims = ("b" , "c" ), shape = (4 , 5 ))
275295 with pytest .raises (
276- ValueError , match = "Input arrays have inconsistent type shape along the axes"
296+ ValueError ,
297+ match = "Size of label 'b' for operand 1.*does not match previous terms" ,
277298 ):
278299 z = x .dot (y )
279300 fn = function ([x , y ], z )
0 commit comments