Skip to content

Commit f25df63

Browse files
add unit tests for log_sinkhorn_plan
1 parent 02c24c1 commit f25df63

File tree

1 file changed

+58
-6
lines changed

1 file changed

+58
-6
lines changed

tests/test_utils/test_optimal_transport.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def test_transport_cost_improves(method):
4141
assert after_cost < before_cost
4242

4343

44-
# @pytest.mark.skip(reason="too unreliable")
45-
@pytest.mark.parametrize("method", ["sinkhorn"])
44+
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
4645
def test_assignment_is_optimal(method):
4746
y = keras.random.normal((16, 2), seed=0)
4847
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)
@@ -72,8 +71,8 @@ def test_assignment_aligns_with_pot():
7271
M = x[:, None] - y[None, :]
7372
M = keras.ops.norm(M, axis=-1)
7473

75-
pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99)
76-
pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0)
74+
pot_plan = sinkhorn_log(a, b, M, reg=1e-3, stopThr=1e-7)
75+
pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
7776
pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)
7877

7978
_, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True)
@@ -107,8 +106,8 @@ def test_sinkhorn_plan_aligns_with_pot():
107106
b = keras.ops.ones(20) / 20
108107
M = euclidean(x1, x2)
109108

110-
pot_result = sinkhorn(a, b, M, 0.1)
111-
our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-8, rtol=1e-8)
109+
pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8)
110+
our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7)
112111

113112
assert_allclose(pot_result, our_result)
114113

@@ -128,3 +127,56 @@ def test_sinkhorn_plan_matches_analytical_result():
128127
expected = keras.ops.outer(marginal_x1, marginal_x2)
129128

130129
assert_allclose(result, expected)
130+
131+
132+
def test_log_sinkhorn_plan_correct_marginals():
133+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
134+
135+
x1 = keras.random.normal((10, 2), seed=0)
136+
x2 = keras.random.normal((20, 2), seed=1)
137+
138+
assert keras.ops.all(
139+
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3)
140+
)
141+
assert keras.ops.all(
142+
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3)
143+
)
144+
145+
146+
def test_log_sinkhorn_plan_aligns_with_pot():
147+
try:
148+
from ot.bregman import sinkhorn_log
149+
except (ImportError, ModuleNotFoundError):
150+
pytest.skip("Need to install POT to run this test.")
151+
152+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
153+
from bayesflow.utils.optimal_transport.euclidean import euclidean
154+
155+
x1 = keras.random.normal((100, 3), seed=0)
156+
x2 = keras.random.normal((200, 3), seed=1)
157+
158+
a = keras.ops.ones(100) / 100
159+
b = keras.ops.ones(200) / 200
160+
M = euclidean(x1, x2)
161+
162+
pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7)) # sinkhorn_log returns probabilities
163+
our_result = log_sinkhorn_plan(x1, x2, regularization=0.1)
164+
165+
assert_allclose(pot_result, our_result)
166+
167+
168+
def test_log_sinkhorn_plan_matches_analytical_result():
169+
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
170+
171+
x1 = keras.ops.ones(16)
172+
x2 = keras.ops.ones(64)
173+
174+
marginal_x1 = keras.ops.ones(16) / 16
175+
marginal_x2 = keras.ops.ones(64) / 64
176+
177+
result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1))
178+
179+
# If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
180+
expected = keras.ops.outer(marginal_x1, marginal_x2)
181+
182+
assert_allclose(result, expected)

0 commit comments

Comments
 (0)