@@ -151,3 +151,167 @@ def test_cast():
151151 yc64 = x .astype ("complex64" )
152152 with pytest .raises (TypeError , match = "Casting from complex to real is ambiguous" ):
153153 yc64 .astype ("float64" )
154+
155+
156+ def test_dot ():
157+ """Test basic dot product operations."""
158+ # Test matrix-vector dot product (with multiple-letter dim names)
159+ x = xtensor ("x" , dims = ("aa" , "bb" ), shape = (2 , 3 ))
160+ y = xtensor ("y" , dims = ("bb" ,), shape = (3 ,))
161+ z = x .dot (y )
162+ fn = xr_function ([x , y ], z )
163+
164+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("aa" , "bb" ))
165+ y_test = DataArray (np .ones (3 ), dims = ("bb" ,))
166+ z_test = fn (x_test , y_test )
167+ expected = x_test .dot (y_test )
168+ xr_assert_allclose (z_test , expected )
169+
170+ # Test matrix-vector dot product with ellipsis
171+ z = x .dot (y , dim = ...)
172+ fn = xr_function ([x , y ], z )
173+ z_test = fn (x_test , y_test )
174+ expected = x_test .dot (y_test , dim = ...)
175+ xr_assert_allclose (z_test , expected )
176+
177+ # Test matrix-matrix dot product
178+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
179+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
180+ z = x .dot (y )
181+ fn = xr_function ([x , y ], z )
182+
183+ x_test = DataArray (np .add .outer (np .arange (2.0 ), np .arange (3.0 )), dims = ("a" , "b" ))
184+ y_test = DataArray (np .add .outer (np .arange (3.0 ), np .arange (4.0 )), dims = ("b" , "c" ))
185+ z_test = fn (x_test , y_test )
186+ expected = x_test .dot (y_test )
187+ xr_assert_allclose (z_test , expected )
188+
189+ # Test matrix-matrix dot product with string dim
190+ z = x .dot (y , dim = "b" )
191+ fn = xr_function ([x , y ], z )
192+ z_test = fn (x_test , y_test )
193+ expected = x_test .dot (y_test , dim = "b" )
194+ xr_assert_allclose (z_test , expected )
195+
196+ # Test matrix-matrix dot product with list of dims
197+ z = x .dot (y , dim = ["b" ])
198+ fn = xr_function ([x , y ], z )
199+ z_test = fn (x_test , y_test )
200+ expected = x_test .dot (y_test , dim = ["b" ])
201+ xr_assert_allclose (z_test , expected )
202+
203+ # Test matrix-matrix dot product with ellipsis
204+ z = x .dot (y , dim = ...)
205+ fn = xr_function ([x , y ], z )
206+ z_test = fn (x_test , y_test )
207+ expected = x_test .dot (y_test , dim = ...)
208+ xr_assert_allclose (z_test , expected )
209+
210+ # Test a case where there are two dimensions to sum over
211+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
212+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
213+ z = x .dot (y )
214+ fn = xr_function ([x , y ], z )
215+
216+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
217+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
218+ z_test = fn (x_test , y_test )
219+ expected = x_test .dot (y_test )
220+ xr_assert_allclose (z_test , expected )
221+
222+ # Same but with explicit dimensions
223+ z = x .dot (y , dim = ["b" , "c" ])
224+ fn = xr_function ([x , y ], z )
225+ z_test = fn (x_test , y_test )
226+ expected = x_test .dot (y_test , dim = ["b" , "c" ])
227+ xr_assert_allclose (z_test , expected )
228+
229+ # Same but with ellipses
230+ z = x .dot (y , dim = ...)
231+ fn = xr_function ([x , y ], z )
232+ z_test = fn (x_test , y_test )
233+ expected = x_test .dot (y_test , dim = ...)
234+ xr_assert_allclose (z_test , expected )
235+
236+ # Dot product with sum
237+ x_test = DataArray (np .arange (24.0 ).reshape (2 , 3 , 4 ), dims = ("a" , "b" , "c" ))
238+ y_test = DataArray (np .arange (60.0 ).reshape (3 , 4 , 5 ), dims = ("b" , "c" , "d" ))
239+ expected = x_test .dot (y_test , dim = ("a" , "b" , "c" ))
240+
241+ x = xtensor ("x" , dims = ("a" , "b" , "c" ), shape = (2 , 3 , 4 ))
242+ y = xtensor ("y" , dims = ("b" , "c" , "d" ), shape = (3 , 4 , 5 ))
243+ z = x .dot (y , dim = ("a" , "b" , "c" ))
244+ fn = xr_function ([x , y ], z )
245+ z_test = fn (x_test , y_test )
246+ xr_assert_allclose (z_test , expected )
247+
248+ # Dot product with sum in the middle
249+ x_test = DataArray (np .arange (120.0 ).reshape (2 , 3 , 4 , 5 ), dims = ("a" , "b" , "c" , "d" ))
250+ y_test = DataArray (np .arange (360.0 ).reshape (3 , 4 , 5 , 6 ), dims = ("b" , "c" , "d" , "e" ))
251+ expected = x_test .dot (y_test , dim = ("b" , "d" ))
252+ x = xtensor ("x" , dims = ("a" , "b" , "c" , "d" ), shape = (2 , 3 , 4 , 5 ))
253+ y = xtensor ("y" , dims = ("b" , "c" , "d" , "e" ), shape = (3 , 4 , 5 , 6 ))
254+ z = x .dot (y , dim = ("b" , "d" ))
255+ fn = xr_function ([x , y ], z )
256+ z_test = fn (x_test , y_test )
257+ xr_assert_allclose (z_test , expected )
258+
259+ # Same but with first two dims
260+ expected = x_test .dot (y_test , dim = ["a" , "b" ])
261+ z = x .dot (y , dim = ["a" , "b" ])
262+ fn = xr_function ([x , y ], z )
263+ z_test = fn (x_test , y_test )
264+ xr_assert_allclose (z_test , expected )
265+
266+ # Same but with last two
267+ expected = x_test .dot (y_test , dim = ["d" , "e" ])
268+ z = x .dot (y , dim = ["d" , "e" ])
269+ fn = xr_function ([x , y ], z )
270+ z_test = fn (x_test , y_test )
271+ xr_assert_allclose (z_test , expected )
272+
273+ # Same but with every other dim
274+ expected = x_test .dot (y_test , dim = ["a" , "c" , "e" ])
275+ z = x .dot (y , dim = ["a" , "c" , "e" ])
276+ fn = xr_function ([x , y ], z )
277+ z_test = fn (x_test , y_test )
278+ xr_assert_allclose (z_test , expected )
279+
280+ # Test symbolic shapes
281+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (None , 3 )) # First dimension is symbolic
282+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , None )) # Second dimension is symbolic
283+ z = x .dot (y )
284+ fn = xr_function ([x , y ], z )
285+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
286+ y_test = DataArray (np .ones ((3 , 4 )), dims = ("b" , "c" ))
287+ z_test = fn (x_test , y_test )
288+ expected = x_test .dot (y_test )
289+ xr_assert_allclose (z_test , expected )
290+
291+
292+ def test_dot_errors ():
293+ # No matching dimensions
294+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
295+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (3 , 4 ))
296+ with pytest .raises (ValueError , match = "Dimension e not found in either input" ):
297+ x .dot (y , dim = "e" )
298+
299+ # Concrete dimension size mismatches
300+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , 3 ))
301+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (4 , 5 ))
302+ with pytest .raises (
303+ ValueError ,
304+ match = "Size of dim 'b' does not match" ,
305+ ):
306+ x .dot (y )
307+
308+ # Symbolic dimension size mismatches
309+ x = xtensor ("x" , dims = ("a" , "b" ), shape = (2 , None ))
310+ y = xtensor ("y" , dims = ("b" , "c" ), shape = (None , 5 ))
311+ z = x .dot (y )
312+ fn = xr_function ([x , y ], z )
313+ x_test = DataArray (np .ones ((2 , 3 )), dims = ("a" , "b" ))
314+ y_test = DataArray (np .ones ((4 , 5 )), dims = ("b" , "c" ))
315+ # Doesn't fail until the rewrite
316+ with pytest .raises (ValueError , match = "not aligned" ):
317+ fn (x_test , y_test )
0 commit comments