Skip to content

Commit 6ee5244

Browse files
re-enable test_assignment_is_optimal() for method='sinkhorn'
1 parent 286bea6 commit 6ee5244

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tests/test_utils/test_optimal_transport.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,20 @@ def test_transport_cost_improves(method):
4141
assert after_cost < before_cost
4242

4343

44-
@pytest.mark.skip(reason="too unreliable")
45-
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
44+
# @pytest.mark.skip(reason="too unreliable")
45+
@pytest.mark.parametrize("method", ["sinkhorn"])
4646
def test_assignment_is_optimal(method):
47-
x = keras.random.normal((16, 2), seed=0)
48-
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
49-
optimal_assignments = keras.ops.argsort(p)
47+
y = keras.random.normal((16, 2), seed=0)
48+
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)
5049

51-
y = x[p]
50+
x = y[p]
5251

53-
x, y, assignments = optimal_transport(
52+
_, _, assignments = optimal_transport(
5453
x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
5554
)
5655

57-
assert_allclose(assignments, optimal_assignments)
56+
# transport is stochastic, so it is expected that a small fraction of assignments do not match
57+
assert keras.ops.sum(assignments == p) > 14
5858

5959

6060
def test_assignment_aligns_with_pot():
@@ -109,6 +109,7 @@ def test_sinkhorn_plan_aligns_with_pot():
109109

110110
pot_result = sinkhorn(a, b, M, 0.1)
111111
our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-8, rtol=1e-8)
112+
112113
assert_allclose(pot_result, our_result)
113114

114115

0 commit comments

Comments
 (0)