 def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
     """
     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
-    Significantly slower than the unstabilized version, so use only when you need numerical stability.
+    About 50% slower than the unstabilized version, so use only when you need numerical stability.
     """
     log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
-    assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
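+    # keras.random.categorical expects logits (unnormalized log-probabilities),
+    # so the plan can be passed in log space directly, without exponentiating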
+    assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
     assignments = keras.ops.squeeze(assignments, axis=1)

     return assignments
@@ -20,21 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
 def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
     """
     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
-    Slightly slower than the unstabilized version, so use primarily when you need numerical stability.
+    About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
     """
     cost = euclidean(x1, x2)
+    cost_scaled = -cost / regularization

-    log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
+    # initialize the transport plan from a Gaussian kernel
+    log_plan = cost_scaled - keras.ops.max(cost_scaled)
+    n, m = keras.ops.shape(log_plan)
+
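+    # uniform target marginals a = 1/n and b = 1/m, kept in log space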
+    log_a = -keras.ops.log(n)
+    log_b = -keras.ops.log(m)

     def contains_nans(plan):
         return keras.ops.any(keras.ops.isnan(plan))

     def is_converged(plan):
-        # for convergence, the plan should be doubly stochastic
-        # NOTE: for small atol and rtol, using rtol_log=0.0 and atol_log=atol + rtol
-        # is equivalent to the convergence check in the unstabilized version
-        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol))
-        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol))
+        # for convergence, the plan's marginals must match the target marginals
+        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
+        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
         return conv0 & conv1

     def cond(_, plan):
@@ -43,8 +47,8 @@ def cond(_, plan):

     def body(steps, plan):
         # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
-        plan = keras.ops.log_softmax(plan, axis=0)
-        plan = keras.ops.log_softmax(plan, axis=1)
+        plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
+        plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a

         return steps + 1, plan

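For context, a minimal usage sketch of the stabilized solver. The import path is an assumption based on the docstring's cross-reference, and the toy inputs are illustrative:

    import keras
    # import path assumed from the docstring's :py:func: reference
    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn

    # two batches of samples to couple, shapes (n, d) and (m, d)
    x1 = keras.random.normal((16, 2), seed=1)
    x2 = keras.random.normal((16, 2), seed=2)

    # sample one hard assignment per row of x1 from the entropy-regularized plan
    assignments = log_sinkhorn(x1, x2, seed=42, regularization=0.1)
    x2_matched = keras.ops.take(x2, assignments, axis=0)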