Skip to content

Commit 02c24c1

Browse files
log_sinkhorn now correctly passes log_plan (rather than keras.ops.exp(log_plan)) to keras.random.categorical; log_sinkhorn_plan now returns the logits of the transport plan
1 parent 6ee5244 commit 02c24c1

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@
88
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
    """
    Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
    About 50% slower than the unstabilized version, so use only when you need numerical stability.

    :param x1: Samples forming the rows of the transport problem.
    :param x2: Samples forming the columns of the transport problem.
    :param seed: Optional seed for the categorical sampling.
    :param kwargs: Forwarded to :py:func:`log_sinkhorn_plan`.
    :return: Tensor of column assignments, one per row of ``x1``.
    """
    # log_sinkhorn_plan returns the *logits* of the transport plan, which is
    # exactly what keras.random.categorical expects (it normalizes internally),
    # so no exponentiation is needed — and doing so would be a bug, since
    # categorical interprets its input as unnormalized log-probabilities.
    log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
    assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
    assignments = keras.ops.squeeze(assignments, axis=1)

    return assignments
@@ -20,21 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
2020
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
2121
"""
2222
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
23-
Slightly slower than the unstabilized version, so use primarily when you need numerical stability.
23+
About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
2424
"""
2525
cost = euclidean(x1, x2)
26+
cost_scaled = -cost / regularization
2627

27-
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
28+
# initialize transport plan from a gaussian kernel
29+
log_plan = cost_scaled - keras.ops.max(cost_scaled)
30+
n, m = keras.ops.shape(log_plan)
31+
32+
log_a = -keras.ops.log(n)
33+
log_b = -keras.ops.log(m)
2834

2935
def contains_nans(plan):
3036
return keras.ops.any(keras.ops.isnan(plan))
3137

3238
def is_converged(plan):
33-
# for convergence, the plan should be doubly stochastic
34-
# NOTE: for small atol and rtol, using rtol_log=0.0 and atol_log=atol + rtol
35-
# is equivalent to the convergence check in the unstabilized version
36-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol))
37-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol))
39+
# for convergence, the target marginals must match
40+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
41+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
3842
return conv0 & conv1
3943

4044
def cond(_, plan):
@@ -43,8 +47,8 @@ def cond(_, plan):
4347

4448
def body(steps, plan):
4549
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
46-
plan = keras.ops.log_softmax(plan, axis=0)
47-
plan = keras.ops.log_softmax(plan, axis=1)
50+
plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
51+
plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a
4852

4953
return steps + 1, plan
5054

0 commit comments

Comments
 (0)