File tree Expand file tree Collapse file tree 2 files changed +10
-18
lines changed
bayesflow/utils/optimal_transport Expand file tree Collapse file tree 2 files changed +10
-18
lines changed Original file line number Diff line number Diff line change @@ -53,22 +53,18 @@ def do_nothing():
5353 pass
5454
5555 def log_steps ():
56- msg = "Log-Sinkhorn-Knopp converged after {:d } steps."
56+ msg = "Log-Sinkhorn-Knopp converged after {} steps."
5757
5858 logging .debug (msg , steps )
5959
6060 def warn_convergence ():
61- marginals = keras .ops .logsumexp (log_plan , axis = 0 )
62- deviations = keras .ops .abs (marginals )
63- badness = 100.0 * keras .ops .exp (keras .ops .max (deviations ))
61+ msg = "Log-Sinkhorn-Knopp did not converge after {} steps."
6462
65- msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
66-
67- logging .warning (msg , max_steps , badness )
63+ logging .warning (msg , max_steps )
6864
6965 def warn_nans ():
70- msg = "Log-Sinkhorn-Knopp produced NaNs."
71- logging .warning (msg )
66+ msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps ."
67+ logging .warning (msg , steps )
7268
7369 keras .ops .cond (contains_nans (log_plan ), warn_nans , do_nothing )
7470 keras .ops .cond (is_converged (log_plan ), log_steps , warn_convergence )
Original file line number Diff line number Diff line change @@ -102,22 +102,18 @@ def do_nothing():
102102 pass
103103
104104 def log_steps ():
105- msg = "Sinkhorn-Knopp converged after {:d } steps."
105+ msg = "Sinkhorn-Knopp converged after {} steps."
106106
107107 logging .info (msg , max_steps )
108108
109109 def warn_convergence ():
110- marginals = keras .ops .sum (plan , axis = 0 )
111- deviations = keras .ops .abs (marginals - 1.0 )
112- badness = 100.0 * keras .ops .max (deviations )
110+ msg = "Sinkhorn-Knopp did not converge after {}."
113111
114- msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
115-
116- logging .warning (msg , max_steps , badness )
112+ logging .warning (msg , max_steps )
117113
118114 def warn_nans ():
119- msg = "Sinkhorn-Knopp produced NaNs."
120- logging .warning (msg )
115+ msg = "Sinkhorn-Knopp produced NaNs after {} steps ."
116+ logging .warning (msg , steps )
121117
122118 keras .ops .cond (contains_nans (plan ), warn_nans , do_nothing )
123119 keras .ops .cond (is_converged (plan ), log_steps , warn_convergence )
You can’t perform that action at this time.
0 commit comments