Skip to content

Commit 5efb00b

Browse files
committed
fix ultra-strict convergence criterion in log_sinkhorn_plan
1 parent 7edf36d commit 5efb00b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def contains_nans(plan):
3131

3232
def is_converged(plan):
3333
# for convergence, the plan should be doubly stochastic
34-
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
35-
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
34+
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol))
35+
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol))
3636
return conv0 & conv1
3737

3838
def cond(_, plan):

0 commit comments

Comments
 (0)