Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions bayesflow/utils/optimal_transport/log_sinkhorn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import keras

from .. import logging
from ..tensor_utils import is_symbolic_tensor

from .euclidean import euclidean

Expand All @@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,

log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)

if is_symbolic_tensor(log_plan):
return log_plan

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))

Expand Down Expand Up @@ -57,22 +53,18 @@ def do_nothing():
pass

def log_steps():
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
msg = "Log-Sinkhorn-Knopp converged after {} steps."

logging.debug(msg, steps)

def warn_convergence():
marginals = keras.ops.logsumexp(log_plan, axis=0)
deviations = keras.ops.abs(marginals)
badness = 100.0 * keras.ops.exp(keras.ops.max(deviations))

msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
msg = "Log-Sinkhorn-Knopp did not converge after {} steps."

logging.warning(msg, max_steps, badness)
logging.warning(msg, max_steps)

def warn_nans():
msg = "Log-Sinkhorn-Knopp produced NaNs."
logging.warning(msg)
msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps."
logging.warning(msg, steps)

keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)
Expand Down
18 changes: 5 additions & 13 deletions bayesflow/utils/optimal_transport/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from bayesflow.types import Tensor

from .. import logging
from ..tensor_utils import is_symbolic_tensor

from .euclidean import euclidean

Expand Down Expand Up @@ -76,9 +75,6 @@ def sinkhorn_plan(
# initialize the transport plan from a gaussian kernel
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))

if is_symbolic_tensor(plan):
return plan

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))

Expand Down Expand Up @@ -106,22 +102,18 @@ def do_nothing():
pass

def log_steps():
msg = "Sinkhorn-Knopp converged after {:d} steps."
msg = "Sinkhorn-Knopp converged after {} steps."

logging.info(msg, max_steps)

def warn_convergence():
marginals = keras.ops.sum(plan, axis=0)
deviations = keras.ops.abs(marginals - 1.0)
badness = 100.0 * keras.ops.max(deviations)

msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
msg = "Sinkhorn-Knopp did not converge after {}."

logging.warning(msg, max_steps, badness)
logging.warning(msg, max_steps)

def warn_nans():
msg = "Sinkhorn-Knopp produced NaNs."
logging.warning(msg)
msg = "Sinkhorn-Knopp produced NaNs after {} steps."
logging.warning(msg, steps)

keras.ops.cond(contains_nans(plan), warn_nans, do_nothing)
keras.ops.cond(is_converged(plan), log_steps, warn_convergence)
Expand Down
1 change: 1 addition & 0 deletions tests/test_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_bayesian_experimental_design(examples_path):
run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb")


@pytest.mark.skip(reason="requires setting up pyabc")
@pytest.mark.slow
def test_from_abc_to_bayesflow(examples_path):
run_notebook(examples_path / "From_ABC_to_BayesFlow.ipynb")
Expand Down