diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 3538eaeff..9fa6dba26 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -1,7 +1,6 @@ import keras from .. import logging -from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16) - if is_symbolic_tensor(log_plan): - return log_plan - def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -57,22 +53,18 @@ def do_nothing(): pass def log_steps(): - msg = "Log-Sinkhorn-Knopp converged after {:d} steps." + msg = "Log-Sinkhorn-Knopp converged after {} steps." logging.debug(msg, steps) def warn_convergence(): - marginals = keras.ops.logsumexp(log_plan, axis=0) - deviations = keras.ops.abs(marginals) - badness = 100.0 * keras.ops.exp(keras.ops.max(deviations)) - - msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)." + msg = "Log-Sinkhorn-Knopp did not converge after {} steps." - logging.warning(msg, max_steps, badness) + logging.warning(msg, max_steps) def warn_nans(): - msg = "Log-Sinkhorn-Knopp produced NaNs." - logging.warning(msg) + msg = "Log-Sinkhorn-Knopp produced NaNs after {} steps." + logging.warning(msg, steps) keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing) keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 1efa5ae0b..04c268eb0 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -3,7 +3,6 @@ from bayesflow.types import Tensor from .. import logging -from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -76,9 +75,6 @@ def sinkhorn_plan( # initialize the transport plan from a gaussian kernel plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16)) - if is_symbolic_tensor(plan): - return plan - def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -106,22 +102,18 @@ def do_nothing(): pass def log_steps(): - msg = "Sinkhorn-Knopp converged after {:d} steps." + msg = "Sinkhorn-Knopp converged after {} steps." logging.info(msg, max_steps) def warn_convergence(): - marginals = keras.ops.sum(plan, axis=0) - deviations = keras.ops.abs(marginals - 1.0) - badness = 100.0 * keras.ops.max(deviations) - - msg = "Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)." + msg = "Sinkhorn-Knopp did not converge after {}." - logging.warning(msg, max_steps, badness) + logging.warning(msg, max_steps) def warn_nans(): - msg = "Sinkhorn-Knopp produced NaNs." - logging.warning(msg) + msg = "Sinkhorn-Knopp produced NaNs after {} steps." + logging.warning(msg, steps) keras.ops.cond(contains_nans(plan), warn_nans, do_nothing) keras.ops.cond(is_converged(plan), log_steps, warn_convergence) diff --git a/tests/test_examples/test_examples.py b/tests/test_examples/test_examples.py index 245052636..40135627a 100644 --- a/tests/test_examples/test_examples.py +++ b/tests/test_examples/test_examples.py @@ -9,6 +9,7 @@ def test_bayesian_experimental_design(examples_path): run_notebook(examples_path / "Bayesian_Experimental_Design.ipynb") +@pytest.mark.skip(reason="requires setting up pyabc") @pytest.mark.slow def test_from_abc_to_bayesflow(examples_path): run_notebook(examples_path / "From_ABC_to_BayesFlow.ipynb")