diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 65aa25e3f..7a097d340 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -39,11 +39,11 @@ class FlowMatching(InferenceNetwork): } OPTIMAL_TRANSPORT_DEFAULT_CONFIG = { - "method": "sinkhorn", - "cost": "euclidean", + "method": "log_sinkhorn", "regularization": 0.1, "max_steps": 100, - "tolerance": 1e-4, + "atol": 1e-5, + "rtol": 1e-4, } INTEGRATE_DEFAULT_CONFIG = { diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 95ea69eb7..2a65d039e 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -1,7 +1,6 @@ import keras from .. import logging -from ..tensor_utils import is_symbolic_tensor from .euclidean import euclidean @@ -27,9 +26,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16) - if is_symbolic_tensor(log_plan): - return log_plan - def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -59,7 +55,7 @@ def do_nothing(): def log_steps(): msg = "Log-Sinkhorn-Knopp converged after {:d} steps." - logging.info(msg, steps) + logging.debug(msg, steps) def warn_convergence(): marginals = keras.ops.logsumexp(log_plan, axis=0) diff --git a/pyproject.toml b/pyproject.toml index 33fdc2786..bdb61fcca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "bayesflow" -version = "2.0.0" +version = "2.0.1" authors = [{ name = "The BayesFlow Team" }] classifiers = [ "Development Status :: 5 - Production/Stable",