@@ -41,20 +41,20 @@ def test_transport_cost_improves(method):
4141 assert after_cost < before_cost
4242
4343
44- @pytest .mark .skip (reason = "too unreliable" )
45- @pytest .mark .parametrize ("method" , ["log_sinkhorn" , " sinkhorn" ])
44+ # @pytest.mark.skip(reason="too unreliable")
45+ @pytest .mark .parametrize ("method" , ["sinkhorn" ])
4646def test_assignment_is_optimal (method ):
47- x = keras .random .normal ((16 , 2 ), seed = 0 )
48- p = keras .random .shuffle (keras .ops .arange (keras .ops .shape (x )[0 ]), seed = 0 )
49- optimal_assignments = keras .ops .argsort (p )
47+ y = keras .random .normal ((16 , 2 ), seed = 0 )
48+ p = keras .random .shuffle (keras .ops .arange (keras .ops .shape (y )[0 ]), seed = 0 )
5049
51- y = x [p ]
50+ x = y [p ]
5251
53- x , y , assignments = optimal_transport (
52+ _ , _ , assignments = optimal_transport (
5453 x , y , regularization = 0.1 , seed = 0 , max_steps = 10_000 , method = method , return_assignments = True
5554 )
5655
57- assert_allclose (assignments , optimal_assignments )
56+ # transport is stochastic, so it is expected that a small fraction of assignments do not match
57+ assert keras .ops .sum (assignments == p ) > 14
5858
5959
6060def test_assignment_aligns_with_pot ():
@@ -109,6 +109,7 @@ def test_sinkhorn_plan_aligns_with_pot():
109109
110110 pot_result = sinkhorn (a , b , M , 0.1 )
111111 our_result = sinkhorn_plan (x1 , x2 , regularization = 0.1 , atol = 1e-8 , rtol = 1e-8 )
112+
112113 assert_allclose (pot_result , our_result )
113114
114115
0 commit comments