@@ -41,8 +41,7 @@ def test_transport_cost_improves(method):
4141 assert after_cost < before_cost
4242
4343
44- # @pytest.mark.skip(reason="too unreliable")
45- @pytest .mark .parametrize ("method" , ["sinkhorn" ])
44+ @pytest .mark .parametrize ("method" , ["log_sinkhorn" , "sinkhorn" ])
4645def test_assignment_is_optimal (method ):
4746 y = keras .random .normal ((16 , 2 ), seed = 0 )
4847 p = keras .random .shuffle (keras .ops .arange (keras .ops .shape (y )[0 ]), seed = 0 )
@@ -72,8 +71,8 @@ def test_assignment_aligns_with_pot():
7271 M = x [:, None ] - y [None , :]
7372 M = keras .ops .norm (M , axis = - 1 )
7473
75- pot_plan = sinkhorn_log (a , b , M , reg = 1e-3 , numItermax = 10_000 , stopThr = 1e-99 )
76- pot_assignments = keras .random .categorical (pot_plan , num_samples = 1 , seed = 0 )
74+ pot_plan = sinkhorn_log (a , b , M , reg = 1e-3 , stopThr = 1e-7 )
75+ pot_assignments = keras .random .categorical (keras . ops . log ( pot_plan ) , num_samples = 1 , seed = 0 )
7776 pot_assignments = keras .ops .squeeze (pot_assignments , axis = - 1 )
7877
7978 _ , _ , assignments = optimal_transport (x , y , regularization = 1e-3 , seed = 0 , max_steps = 10_000 , return_assignments = True )
@@ -107,8 +106,8 @@ def test_sinkhorn_plan_aligns_with_pot():
107106 b = keras .ops .ones (20 ) / 20
108107 M = euclidean (x1 , x2 )
109108
110- pot_result = sinkhorn (a , b , M , 0.1 )
111- our_result = sinkhorn_plan (x1 , x2 , regularization = 0.1 , atol = 1e-8 , rtol = 1e-8 )
109+ pot_result = sinkhorn (a , b , M , 0.1 , stopThr = 1e-8 )
110+ our_result = sinkhorn_plan (x1 , x2 , regularization = 0.1 , rtol = 1e-7 )
112111
113112 assert_allclose (pot_result , our_result )
114113
@@ -128,3 +127,56 @@ def test_sinkhorn_plan_matches_analytical_result():
128127 expected = keras .ops .outer (marginal_x1 , marginal_x2 )
129128
130129 assert_allclose (result , expected )
130+
131+
132+ def test_log_sinkhorn_plan_correct_marginals ():
133+ from bayesflow .utils .optimal_transport .log_sinkhorn import log_sinkhorn_plan
134+
135+ x1 = keras .random .normal ((10 , 2 ), seed = 0 )
136+ x2 = keras .random .normal ((20 , 2 ), seed = 1 )
137+
138+ assert keras .ops .all (
139+ keras .ops .isclose (keras .ops .logsumexp (log_sinkhorn_plan (x1 , x2 ), axis = 0 ), - keras .ops .log (20 ), atol = 1e-3 )
140+ )
141+ assert keras .ops .all (
142+ keras .ops .isclose (keras .ops .logsumexp (log_sinkhorn_plan (x1 , x2 ), axis = 1 ), - keras .ops .log (10 ), atol = 1e-3 )
143+ )
144+
145+
146+ def test_log_sinkhorn_plan_aligns_with_pot ():
147+ try :
148+ from ot .bregman import sinkhorn_log
149+ except (ImportError , ModuleNotFoundError ):
150+ pytest .skip ("Need to install POT to run this test." )
151+
152+ from bayesflow .utils .optimal_transport .log_sinkhorn import log_sinkhorn_plan
153+ from bayesflow .utils .optimal_transport .euclidean import euclidean
154+
155+ x1 = keras .random .normal ((100 , 3 ), seed = 0 )
156+ x2 = keras .random .normal ((200 , 3 ), seed = 1 )
157+
158+ a = keras .ops .ones (100 ) / 100
159+ b = keras .ops .ones (200 ) / 200
160+ M = euclidean (x1 , x2 )
161+
162+ pot_result = keras .ops .log (sinkhorn_log (a , b , M , 0.1 , stopThr = 1e-7 )) # sinkhorn_log returns probabilities
163+ our_result = log_sinkhorn_plan (x1 , x2 , regularization = 0.1 )
164+
165+ assert_allclose (pot_result , our_result )
166+
167+
168+ def test_log_sinkhorn_plan_matches_analytical_result ():
169+ from bayesflow .utils .optimal_transport .log_sinkhorn import log_sinkhorn_plan
170+
171+ x1 = keras .ops .ones (16 )
172+ x2 = keras .ops .ones (64 )
173+
174+ marginal_x1 = keras .ops .ones (16 ) / 16
175+ marginal_x2 = keras .ops .ones (64 ) / 64
176+
177+ result = keras .ops .exp (log_sinkhorn_plan (x1 , x2 , regularization = 0.1 ))
178+
179+ # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
180+ expected = keras .ops .outer (marginal_x1 , marginal_x2 )
181+
182+ assert_allclose (result , expected )
0 commit comments