@@ -168,7 +168,7 @@ def test_dot():
168168 xr_assert_allclose (z_test , expected )
169169
170170 # Test matrix-vector dot product with ellipsis
171- z = x .dot (y , dims = ...)
171+ z = x .dot (y , dim = ...)
172172 fn = xr_function ([x , y ], z )
173173 z_test = fn (x_test , y_test )
174174 expected = x_test .dot (y_test , dim = ...)
@@ -186,22 +186,22 @@ def test_dot():
186186 expected = x_test .dot (y_test )
187187 xr_assert_allclose (z_test , expected )
188188
189- # Test matrix-matrix dot product with string dims
190- z = x .dot (y , dims = "b" )
189+ # Test matrix-matrix dot product with string dim
190+ z = x .dot (y , dim = "b" )
191191 fn = xr_function ([x , y ], z )
192192 z_test = fn (x_test , y_test )
193193 expected = x_test .dot (y_test , dim = "b" )
194194 xr_assert_allclose (z_test , expected )
195195
196196 # Test matrix-matrix dot product with list of dims
197- z = x .dot (y , dims = ["b" ])
197+ z = x .dot (y , dim = ["b" ])
198198 fn = xr_function ([x , y ], z )
199199 z_test = fn (x_test , y_test )
200200 expected = x_test .dot (y_test , dim = ["b" ])
201201 xr_assert_allclose (z_test , expected )
202202
203203 # Test matrix-matrix dot product with ellipsis
204- z = x .dot (y , dims = ...)
204+ z = x .dot (y , dim = ...)
205205 fn = xr_function ([x , y ], z )
206206 z_test = fn (x_test , y_test )
207207 expected = x_test .dot (y_test , dim = ...)
@@ -220,28 +220,49 @@ def test_dot():
220220 xr_assert_allclose (z_test , expected )
221221
222222 # Same but with explicit dimensions
223- z = x .dot (y , dims = ["b" , "c" ])
223+ z = x .dot (y , dim = ["b" , "c" ])
224224 fn = xr_function ([x , y ], z )
225225 z_test = fn (x_test , y_test )
226226 expected = x_test .dot (y_test , dim = ["b" , "c" ])
227227 xr_assert_allclose (z_test , expected )
228228
229229 # Same but with ellipses
230- z = x .dot (y , dims = ...)
230+ z = x .dot (y , dim = ...)
231231 fn = xr_function ([x , y ], z )
232232
233233 z_test = fn (x_test , y_test )
234234 expected = x_test .dot (y_test , dim = ...)
235235 xr_assert_allclose (z_test , expected )
236236
237+ # Dot product with sum
238+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
239+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
240+ expected = x_test .dot (y_test , dim = ("a" , "b" , "c" ))
241+
242+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
243+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
244+ z = x .dot (y , dim = ("a" , "b" , "c" ))
245+ fn = xr_function ([x , y ], z )
246+ z_test = fn (x_test , y_test )
247+ xr_assert_allclose (z_test , expected )
248+
249+ return
250+ # Dot product with sum in the middle
251+ # This is not supported yet
252+ x_test = DataArray (np .arange (120 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
253+ y_test = DataArray (np .arange (360 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
254+ expected = x_test .dot (y_test , dim = ("b" , "d" ))
255+ x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
256+ y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
257+ z = x .dot (y , dim = ("b" , "d" ))
258+ fn = xr_function ([x , y ], z )
259+ z_test = fn (x_test , y_test )
260+ xr_assert_allclose (z_test , expected )
261+
237262
238263def test_dot_errors ():
239264 x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
240265 y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
241- with pytest .raises (ValueError , match = "Dimension c not found in first input" ):
242- x .dot (y , dims = ["c" ])
243- with pytest .raises (ValueError , match = "Dimension a not found in second input" ):
244- x .dot (y , dims = ["a" ])
245266
246267 # Test a case where there are no matching dimensions
247268 x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
0 commit comments