Skip to content

Commit 23b1799

Browse files
committed
Cleanup
1 parent 18501bf commit 23b1799

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

tests/xtensor/test_math.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,19 @@ def test_dot():
167167
expected = x_test.dot(y_test)
168168
xr_assert_allclose(z_test, expected)
169169

170+
# Test matrix-vector dot product with ellipsis
171+
z = x.dot(y, dims=...)
172+
fn = xr_function([x, y], z)
173+
z_test = fn(x_test, y_test)
174+
expected = x_test.dot(y_test, dim=...)
175+
xr_assert_allclose(z_test, expected)
176+
170177
# Test matrix-matrix dot product
171178
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
172179
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
173180
z = x.dot(y)
174181
fn = xr_function([x, y], z)
175182

176-
# Use outer product to create test data with diverse values
177183
x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
178184
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
179185
z_test = fn(x_test, y_test)
@@ -195,14 +201,13 @@ def test_dot():
195201
xr_assert_allclose(z_test, expected)
196202

197203
# Test matrix-matrix dot product with ellipsis
198-
if True:
199-
z = x.dot(y, dims=...)
200-
fn = xr_function([x, y], z)
201-
z_test = fn(x_test, y_test)
202-
expected = x_test.dot(y_test, dim=...)
203-
xr_assert_allclose(z_test, expected)
204-
205-
# Test a case where there are two dimensions to contract over
204+
z = x.dot(y, dims=...)
205+
fn = xr_function([x, y], z)
206+
z_test = fn(x_test, y_test)
207+
expected = x_test.dot(y_test, dim=...)
208+
xr_assert_allclose(z_test, expected)
209+
210+
# Test a case where there are two dimensions to sum over
206211
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
207212
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
208213
z = x.dot(y)
@@ -222,13 +227,12 @@ def test_dot():
222227
xr_assert_allclose(z_test, expected)
223228

224229
# Same but with ellipses
225-
if True:
226-
z = x.dot(y, dims=...)
227-
fn = xr_function([x, y], z)
230+
z = x.dot(y, dims=...)
231+
fn = xr_function([x, y], z)
228232

229-
z_test = fn(x_test, y_test)
230-
expected = x_test.dot(y_test, dim=...)
231-
xr_assert_allclose(z_test, expected)
233+
z_test = fn(x_test, y_test)
234+
expected = x_test.dot(y_test, dim=...)
235+
xr_assert_allclose(z_test, expected)
232236

233237

234238
def test_dot_errors():

0 commit comments

Comments
 (0)