Skip to content

Commit 1305047

Browse files
committed
fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse
1 parent 6a9bcc7 commit 1305047

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,18 @@ def do_nothing():
5353
pass
5454

5555
def log_steps():
56-
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
56+
msg = "Log-Sinkhorn-Knopp converged after {} steps."
5757

5858
logging.debug(msg, steps)
5959

6060
def warn_convergence():
61-
marginals = keras.ops.logsumexp(log_plan, axis=0)
62-
deviations = keras.ops.abs(marginals)
63-
badness = 100.0 * keras.ops.exp(keras.ops.max(deviations))
61+
msg = "Log-Sinkhorn-Knopp did not converge after {} steps."
6462

65-
msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
66-
67-
logging.warning(msg, max_steps, badness)
63+
logging.warning(msg, max_steps)
6864

6965
def warn_nans():
70-
msg = "Log-Sinkhorn-Knopp produced NaNs."
71-
logging.warning(msg)
66+
msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps."
67+
logging.warning(msg, steps)
7268

7369
keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
7470
keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,18 @@ def do_nothing():
102102
pass
103103

104104
def log_steps():
105-
msg = "Sinkhorn-Knopp converged after {:d} steps."
105+
msg = "Sinkhorn-Knopp converged after {} steps."
106106

107107
logging.info(msg, max_steps)
108108

109109
def warn_convergence():
110-
marginals = keras.ops.sum(plan, axis=0)
111-
deviations = keras.ops.abs(marginals - 1.0)
112-
badness = 100.0 * keras.ops.max(deviations)
110+
msg = "Sinkhorn-Knopp did not converge after {}."
113111

114-
msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
115-
116-
logging.warning(msg, max_steps, badness)
112+
logging.warning(msg, max_steps)
117113

118114
def warn_nans():
119-
msg = "Sinkhorn-Knopp produced NaNs."
120-
logging.warning(msg)
115+
msg = "Sinkhorn-Knopp produced NaNs after {} steps."
116+
logging.warning(msg, steps)
121117

122118
keras.ops.cond(contains_nans(plan), warn_nans, do_nothing)
123119
keras.ops.cond(is_converged(plan), log_steps, warn_convergence)

0 commit comments

Comments
 (0)