Skip to content

Commit 6ce20d0

Browse files
committed
Cleanup
1 parent e44ca9f commit 6ce20d0

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

tests/xtensor/test_math.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ def test_cast():
155155

156156
def test_dot():
157157
"""Test basic dot product operations."""
158-
# Test matrix-vector dot product
159-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
160-
y = xtensor("y", dims=("b",), shape=(3,))
158+
# Test matrix-vector dot product (with multiple-letter dim names)
159+
x = xtensor("x", dims=("aa", "bb"), shape=(2, 3))
160+
y = xtensor("y", dims=("bb",), shape=(3,))
161161
z = x.dot(y)
162162
fn = xr_function([x, y], z)
163163

164-
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
165-
y_test = DataArray(np.ones(3), dims=("b",))
164+
x_test = DataArray(np.ones((2, 3)), dims=("aa", "bb"))
165+
y_test = DataArray(np.ones(3), dims=("bb",))
166166
z_test = fn(x_test, y_test)
167167
expected = x_test.dot(y_test)
168168
xr_assert_allclose(z_test, expected)
@@ -229,7 +229,6 @@ def test_dot():
229229
# Same but with ellipses
230230
z = x.dot(y, dim=...)
231231
fn = xr_function([x, y], z)
232-
233232
z_test = fn(x_test, y_test)
234233
expected = x_test.dot(y_test, dim=...)
235234
xr_assert_allclose(z_test, expected)
@@ -247,7 +246,6 @@ def test_dot():
247246
xr_assert_allclose(z_test, expected)
248247

249248
# Dot product with sum in the middle
250-
# This is not supported yet
251249
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
252250
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
253251
expected = x_test.dot(y_test, dim=("b", "d"))

0 commit comments

Comments
 (0)