@@ -155,14 +155,14 @@ def test_cast():
155155
156156def test_dot ():
157157 """Test basic dot product operations."""
158- # Test matrix-vector dot product
159- x = xtensor ("x" , dims = ("a " , "b " ), shape = (2 , 3 ))
160- y = xtensor ("y" , dims = ("b " ,), shape = (3 ,))
158+ # Test matrix-vector dot product (with multiple-letter dim names)
159+ x = xtensor ("x" , dims = ("aa " , "bb " ), shape = (2 , 3 ))
160+ y = xtensor ("y" , dims = ("bb " ,), shape = (3 ,))
161161 z = x .dot (y )
162162 fn = xr_function ([x , y ], z )
163163
164- x_test = DataArray (np .ones ((2 , 3 )), dims = ("a " , "b " ))
165- y_test = DataArray (np .ones (3 ), dims = ("b " ,))
164+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("aa " , "bb " ))
165+ y_test = DataArray (np .ones (3 ), dims = ("bb " ,))
166166 z_test = fn (x_test , y_test )
167167 expected = x_test .dot (y_test )
168168 xr_assert_allclose (z_test , expected )
@@ -229,7 +229,6 @@ def test_dot():
229229 # Same but with ellipses
230230 z = x .dot (y , dim = ...)
231231 fn = xr_function ([x , y ], z )
232-
233232 z_test = fn (x_test , y_test )
234233 expected = x_test .dot (y_test , dim = ...)
235234 xr_assert_allclose (z_test , expected )
@@ -247,7 +246,6 @@ def test_dot():
247246 xr_assert_allclose (z_test , expected )
248247
249248 # Dot product with sum in the middle
250- # This is not supported yet
251249 x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
252250 y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
253251 expected = x_test .dot (y_test , dim = ("b" , "d" ))
0 commit comments