From e95f386156b12aab869c5a2e285ff8535c5dc9d3 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 11:22:22 +0200 Subject: [PATCH 01/15] fix numerical stability issues in sinkhorn plan --- bayesflow/utils/optimal_transport/sinkhorn.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index f7e0ba835..466b84e84 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -42,7 +42,7 @@ def sinkhorn_plan( x1: Tensor, x2: Tensor, regularization: float = 1.0, - max_steps: int = 10_000, + max_steps: int = None, rtol: float = 1e-5, atol: float = 1e-8, ) -> Tensor: @@ -59,7 +59,7 @@ def sinkhorn_plan( Controls the standard deviation of the Gaussian kernel. :param max_steps: Maximum number of iterations, or None to run until convergence. - Default: 10_000 + Default: None :param rtol: Relative tolerance for convergence. Default: 1e-5. @@ -73,7 +73,11 @@ def sinkhorn_plan( cost = euclidean(x1, x2) # initialize the transport plan from a gaussian kernel - plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16)) + logits = cost / -(regularization * keras.ops.mean(cost) + 1e-16) + # numerical stability: we can subtract the maximum without changing the result + logits = logits - keras.ops.max(logits) + # exponentiate to get the initial transport plan + plan = keras.ops.exp(logits) def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) @@ -90,8 +94,8 @@ def cond(_, plan): def body(steps, plan): # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension - plan = keras.ops.softmax(plan, axis=0) - plan = keras.ops.softmax(plan, axis=1) + plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) + plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) return steps + 1, plan From 7edf36def8e9d79b56f60ec5861deb61e0755451 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 11:22:40 +0200 Subject: [PATCH 02/15] improve test suite --- tests/test_utils/test_optimal_transport.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index dc7b7e93a..21178583d 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -27,13 +27,14 @@ def test_shapes(method): assert keras.ops.shape(oy) == keras.ops.shape(y) -def test_transport_cost_improves(): +@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) +def test_transport_cost_improves(method): x = keras.random.normal((128, 2), seed=0) y = keras.random.normal((128, 2), seed=1) before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) - x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000) + x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method) after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) @@ -41,14 +42,17 @@ def test_transport_cost_improves(): @pytest.mark.skip(reason="too unreliable") -def test_assignment_is_optimal(): +@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) +def test_assignment_is_optimal(method): x = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0) optimal_assignments = keras.ops.argsort(p) y = x[p] - x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, 
return_assignments=True) + x, y, assignments = optimal_transport( + x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True + ) assert_allclose(assignments, optimal_assignments) From 5efb00ba676995249470add65c8c85a7a700e82f Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 12:12:02 +0200 Subject: [PATCH 03/15] fix ultra-strict convergence criterion in log_sinkhorn_plan --- bayesflow/utils/optimal_transport/log_sinkhorn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 9fa6dba26..5e1133e52 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -31,8 +31,8 @@ def contains_nans(plan): def is_converged(plan): # for convergence, the plan should be doubly stochastic - conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol)) - conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol)) + conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol)) + conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol)) return conv0 & conv1 def cond(_, plan): From 557095a9cbf757aa7f76f3403892db7b9e2a7c98 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 12:12:08 +0200 Subject: [PATCH 04/15] update dependencies --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b37a08e15..b7ccc7cb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,16 +36,16 @@ dependencies = [ [project.optional-dependencies] all = [ # dev + "ipython", + "ipykernel", "jupyter", "jupyterlab", + "line-profiler", "nbconvert", - "ipython", - "ipykernel", "pre-commit", "ruff", "tox", # docs - "myst-nb ~= 1.2", "numpydoc ~= 1.8", "pydata-sphinx-theme ~= 0.16", @@ -63,6 +63,7 @@ all = [ dev = [ "jupyter", "jupyterlab", + "line-profiler", "pre-commit", "ruff", "tox", From e1ef07d7913e5e961b74e32e6dbc1de74dd115ac Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 12:15:46 +0200 Subject: [PATCH 05/15] add comment about convergence check --- bayesflow/utils/optimal_transport/log_sinkhorn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 5e1133e52..906d903c5 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -31,6 +31,8 @@ def contains_nans(plan): def is_converged(plan): # for convergence, the plan should be doubly stochastic + # NOTE: for small atol and rtol, using rtol_log=0.0 and atol_log=atol + rtol + # is equivalent to the convergence check in the unstabilized version conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol)) conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol)) return conv0 & conv1 From a0e36df364fc19cf7c3c3d7904c2e290389cb8e1 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 4 Jul 2025 12:23:29 +0200 Subject: [PATCH 06/15] update docsting to reflect fixes --- bayesflow/utils/optimal_transport/log_sinkhorn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py 
b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 906d903c5..4c313dfe9 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -20,7 +20,7 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs): def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None): """ Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`. - Significantly slower than the unstabilized version, so use only when you need numerical stability. + Slightly slower than the unstabilized version, so use primarily when you need numerical stability. """ cost = euclidean(x1, x2) From bbad711bb8b2520229753d00b6af82d5845e8359 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 13:08:23 +0200 Subject: [PATCH 07/15] sinkhorn_plan now returns a transport plan with uniform marginal distributions --- bayesflow/utils/optimal_transport/sinkhorn.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 466b84e84..e90c331ef 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -71,21 +71,20 @@ def sinkhorn_plan( The transport probabilities. """ cost = euclidean(x1, x2) + cost_scaled = -cost / regularization - # initialize the transport plan from a gaussian kernel - logits = cost / -(regularization * keras.ops.mean(cost) + 1e-16) - # numerical stability: we can subtract the maximum without changing the result - logits = logits - keras.ops.max(logits) - # exponentiate to get the initial transport plan - plan = keras.ops.exp(logits) + # initialize transport plan from a gaussian kernel + # (more numerically stable version of keras.ops.exp(-cost/regularization)) + plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled)) + n, m = keras.ops.shape(cost) def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) def is_converged(plan): - # for convergence, the plan should be doubly stochastic - conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0, rtol=rtol, atol=atol)) - conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0, rtol=rtol, atol=atol)) + # for convergence, the target marginals must match + conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol)) + conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol)) return conv0 & conv1 def cond(_, plan): @@ -94,8 +93,8 @@ def cond(_, plan): def body(steps, plan): # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension - plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) - plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) + plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m) + plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n) return steps + 1, plan From 83b01c9d82bc09aba7201e204b4db0062aa87c34 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 13:57:42 +0200 Subject: [PATCH 08/15] add unit test for sinkhorn_plan --- tests/test_utils/test_optimal_transport.py | 48 ++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_utils/test_optimal_transport.py 
b/tests/test_utils/test_optimal_transport.py index 21178583d..dcd840cea 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -79,3 +79,51 @@ def test_assignment_aligns_with_pot(): _, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True) assert_allclose(pot_assignments, assignments) + + +def test_sinkhorn_plan_correct_marginals(): + from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan + + x1 = keras.random.normal((10, 2), seed=0) + x2 = keras.random.normal((20, 2), seed=1) + + assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6)) + assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6)) + + +def test_sinkhorn_plan_aligns_with_pot(): + try: + from ot.bregman import sinkhorn + except (ImportError, ModuleNotFoundError): + pytest.skip("Need to install POT to run this test.") + + from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan + from bayesflow.utils.optimal_transport.euclidean import euclidean + + x1 = keras.random.normal((10, 3), seed=0) + x2 = keras.random.normal((20, 3), seed=1) + + a = keras.ops.ones(10) / 10 + b = keras.ops.ones(20) / 20 + M = euclidean(x1, x2) + + pot_result = sinkhorn(a, b, M, 0.1) + our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-8, rtol=1e-8) + assert_allclose(pot_result, our_result) + + +def test_sinkhorn_plan_matches_analytical_result(): + from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan + + x1 = keras.ops.ones(16) + x2 = keras.ops.ones(64) + + marginal_x1 = keras.ops.ones(16) / 16 + marginal_x2 = keras.ops.ones(64) / 64 + + result = sinkhorn_plan(x1, x2, regularization=0.1) + + # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals + expected = keras.ops.outer(marginal_x1, marginal_x2) + + assert_allclose(result, expected) From 4502a365880ea0f74c7eaf5702702c07360c1e16 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 14:33:22 +0200 Subject: [PATCH 09/15] fix sinkhorn function by sampling from the logits of the transpose of the plan, instead of the plan directly --- bayesflow/utils/optimal_transport/sinkhorn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index e90c331ef..0fb0ce73b 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten """ Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm. - Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic + Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a transport plan, containing assignment probabilities. The permutation is then sampled randomly according to the transport plan. @@ -31,8 +31,10 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten Assignment indices for x2. 
""" - plan = sinkhorn_plan(x1, x2, **kwargs) - assignments = keras.random.categorical(plan, num_samples=1, seed=seed) + plan = sinkhorn_plan(x1, x2, **kwargs) # shape: (n, m) + + # we sample from plan.T to receive assignments of length m, with elements up to n + assignments = keras.random.categorical(keras.ops.log(plan.T), num_samples=1, seed=seed) assignments = keras.ops.squeeze(assignments, axis=1) return assignments From 286bea6143661f0e97d3a74c4fa2a17b56b3de7a Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 17:59:04 +0200 Subject: [PATCH 10/15] sinkhorn(x1, x2) now samples from log(plan) to receive assignments such that x2[assignments] matches x1 --- bayesflow/utils/optimal_transport/sinkhorn.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py index 0fb0ce73b..45c568294 100644 --- a/bayesflow/utils/optimal_transport/sinkhorn.py +++ b/bayesflow/utils/optimal_transport/sinkhorn.py @@ -27,14 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten :param seed: Random seed to use for sampling indices. Default: None, which means the seed will be auto-determined for non-compiled contexts. - :return: Tensor of shape (m,) + :return: Tensor of shape (n,) Assignment indices for x2. """ - plan = sinkhorn_plan(x1, x2, **kwargs) # shape: (n, m) + plan = sinkhorn_plan(x1, x2, **kwargs) - # we sample from plan.T to receive assignments of length m, with elements up to n - assignments = keras.random.categorical(keras.ops.log(plan.T), num_samples=1, seed=seed) + # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2 + # such that x2[assignments] matches x1 + assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed) assignments = keras.ops.squeeze(assignments, axis=1) return assignments From 6ee5244a504c8deda9af4911b4a01cb154dfe3b4 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 18:00:55 +0200 Subject: [PATCH 11/15] re-enable test_assignment_is_optimal() for method='sinkhorn' --- tests/test_utils/test_optimal_transport.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index dcd840cea..7d4645440 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -41,20 +41,20 @@ def test_transport_cost_improves(method): assert after_cost < before_cost -@pytest.mark.skip(reason="too unreliable") -@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) +# @pytest.mark.skip(reason="too unreliable") +@pytest.mark.parametrize("method", ["sinkhorn"]) def test_assignment_is_optimal(method): - x = keras.random.normal((16, 2), seed=0) - p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0) - optimal_assignments = keras.ops.argsort(p) + y = keras.random.normal((16, 2), seed=0) + p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0) - y = x[p] + x = y[p] - x, y, assignments = optimal_transport( + _, _, assignments = optimal_transport( x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True ) - assert_allclose(assignments, optimal_assignments) + # transport is stochastic, so it is expected that a small 
fraction of assignments do not match + assert keras.ops.sum(assignments == p) > 14 def test_assignment_aligns_with_pot(): @@ -109,6 +109,7 @@ def test_sinkhorn_plan_aligns_with_pot(): pot_result = sinkhorn(a, b, M, 0.1) our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-8, rtol=1e-8) + assert_allclose(pot_result, our_result) From 02c24c1dd90811c3a0b38691efe4604067eecf0d Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 18:29:28 +0200 Subject: [PATCH 12/15] log_sinkhorn now correctly uses log_plan instead of keras.ops.exp(log_plan), log_sinkhorn_plan returns logits of the transport plan --- .../utils/optimal_transport/log_sinkhorn.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py index 4c313dfe9..2def2b0c7 100644 --- a/bayesflow/utils/optimal_transport/log_sinkhorn.py +++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py @@ -8,10 +8,10 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs): """ Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`. - Significantly slower than the unstabilized version, so use only when you need numerical stability. + About 50% slower than the unstabilized version, so use only when you need numerical stability. """ log_plan = log_sinkhorn_plan(x1, x2, **kwargs) - assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed) + assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed) assignments = keras.ops.squeeze(assignments, axis=1) return assignments @@ -20,21 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs): def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None): """ Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`. - Slightly slower than the unstabilized version, so use primarily when you need numerical stability. + About 50% slower than the unstabilized version, so use primarily when you need numerical stability. 
""" cost = euclidean(x1, x2) + cost_scaled = -cost / regularization - log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16) + # initialize transport plan from a gaussian kernel + log_plan = cost_scaled - keras.ops.max(cost_scaled) + n, m = keras.ops.shape(log_plan) + + log_a = -keras.ops.log(n) + log_b = -keras.ops.log(m) def contains_nans(plan): return keras.ops.any(keras.ops.isnan(plan)) def is_converged(plan): - # for convergence, the plan should be doubly stochastic - # NOTE: for small atol and rtol, using rtol_log=0.0 and atol_log=atol + rtol - # is equivalent to the convergence check in the unstabilized version - conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=0.0, atol=atol + rtol)) - conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=0.0, atol=atol + rtol)) + # for convergence, the target marginals must match + conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol)) + conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol)) return conv0 & conv1 def cond(_, plan): @@ -43,8 +47,8 @@ def cond(_, plan): def body(steps, plan): # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension - plan = keras.ops.log_softmax(plan, axis=0) - plan = keras.ops.log_softmax(plan, axis=1) + plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b + plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a return steps + 1, plan From f25df63bf1200b16be91db2e2cef1f83e198d421 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 19:32:16 +0200 Subject: [PATCH 13/15] add unit tests for log_sinkhorn_plan --- tests/test_utils/test_optimal_transport.py | 64 ++++++++++++++++++++-- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index 7d4645440..837d600e4 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -41,8 +41,7 @@ def test_transport_cost_improves(method): assert after_cost < before_cost -# @pytest.mark.skip(reason="too unreliable") -@pytest.mark.parametrize("method", ["sinkhorn"]) +@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"]) def test_assignment_is_optimal(method): y = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0) @@ -72,8 +71,8 @@ def test_assignment_aligns_with_pot(): M = x[:, None] - y[None, :] M = keras.ops.norm(M, axis=-1) - pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99) - pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0) + pot_plan = sinkhorn_log(a, b, M, reg=1e-3, stopThr=1e-7) + pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0) pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1) _, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True) @@ -107,8 +106,8 @@ def test_sinkhorn_plan_aligns_with_pot(): b = keras.ops.ones(20) / 20 M = euclidean(x1, x2) - pot_result = sinkhorn(a, b, M, 0.1) - our_result = sinkhorn_plan(x1, x2, regularization=0.1, atol=1e-8, rtol=1e-8) + pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8) + our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7) 
assert_allclose(pot_result, our_result) @@ -128,3 +127,56 @@ def test_sinkhorn_plan_matches_analytical_result(): expected = keras.ops.outer(marginal_x1, marginal_x2) assert_allclose(result, expected) + + +def test_log_sinkhorn_plan_correct_marginals(): + from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan + + x1 = keras.random.normal((10, 2), seed=0) + x2 = keras.random.normal((20, 2), seed=1) + + assert keras.ops.all( + keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3) + ) + assert keras.ops.all( + keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3) + ) + + +def test_log_sinkhorn_plan_aligns_with_pot(): + try: + from ot.bregman import sinkhorn_log + except (ImportError, ModuleNotFoundError): + pytest.skip("Need to install POT to run this test.") + + from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan + from bayesflow.utils.optimal_transport.euclidean import euclidean + + x1 = keras.random.normal((100, 3), seed=0) + x2 = keras.random.normal((200, 3), seed=1) + + a = keras.ops.ones(100) / 100 + b = keras.ops.ones(200) / 200 + M = euclidean(x1, x2) + + pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7)) # sinkhorn_log returns probabilities + our_result = log_sinkhorn_plan(x1, x2, regularization=0.1) + + assert_allclose(pot_result, our_result) + + +def test_log_sinkhorn_plan_matches_analytical_result(): + from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan + + x1 = keras.ops.ones(16) + x2 = keras.ops.ones(64) + + marginal_x1 = keras.ops.ones(16) / 16 + marginal_x2 = keras.ops.ones(64) / 64 + + result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1)) + + # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals + expected = keras.ops.outer(marginal_x1, marginal_x2) + + assert_allclose(result, expected) From 8f7ddb27d418dd214ae783b9edeecff06a80f2dd Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 5 Jul 2025 22:07:10 +0200 Subject: [PATCH 14/15] fix faulty indexing with tensor for tensorflow backend --- tests/test_utils/test_optimal_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index 837d600e4..9d6021a8a 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -46,7 +46,7 @@ def test_assignment_is_optimal(method): y = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0) - x = y[p] + x = keras.ops.take(y, p, axis=0) _, _, assignments = optimal_transport( x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True From b0f71b984e479198e45270b205d169e3424d50c8 Mon Sep 17 00:00:00 2001 From: larskue Date: Tue, 8 Jul 2025 14:33:53 +0200 Subject: [PATCH 15/15] re-add numItermax for ot pot test --- tests/test_utils/test_optimal_transport.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py index 9d6021a8a..53a5fd7a6 100644 --- a/tests/test_utils/test_optimal_transport.py +++ b/tests/test_utils/test_optimal_transport.py @@ -61,6 +61,7 @@ def test_assignment_aligns_with_pot(): from ot.bregman import sinkhorn_log except (ImportError, 
ModuleNotFoundError): pytest.skip("Need to install POT to run this test.") + return x = keras.random.normal((16, 2), seed=0) p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0) @@ -71,7 +72,7 @@ def test_assignment_aligns_with_pot(): M = x[:, None] - y[None, :] M = keras.ops.norm(M, axis=-1) - pot_plan = sinkhorn_log(a, b, M, reg=1e-3, stopThr=1e-7) + pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7) pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0) pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)
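
--
Reviewer note (illustration only, not part of the patches): patches 07 and 12 change the Sinkhorn iteration to normalize toward uniform marginals 1/n and 1/m rather than toward a doubly stochastic matrix, and patch 12 keeps everything in log space. The sketch below restates that log-domain iteration outside the BayesFlow/Keras codebase so the math is easy to check in isolation. It uses NumPy/SciPy instead of keras.ops, computes the Euclidean cost inline, and the function name log_sinkhorn_plan_sketch, the tolerance, and the step limit are illustrative assumptions, not the library API.

    # Minimal standalone sketch of the log-domain Sinkhorn iteration with uniform marginals.
    # Assumptions: NumPy/SciPy in place of keras.ops; cost matrix computed inline.
    import numpy as np
    from scipy.special import logsumexp


    def log_sinkhorn_plan_sketch(x1, x2, regularization=1.0, tol=1e-8, max_steps=1000):
        # pairwise Euclidean cost matrix of shape (n, m)
        cost = np.linalg.norm(x1[:, None, :] - x2[None, :, :], axis=-1)

        # initial log-plan; subtracting the max only rescales the unnormalized plan
        log_plan = -cost / regularization
        log_plan -= log_plan.max()

        n, m = log_plan.shape
        log_a, log_b = -np.log(n), -np.log(m)  # logs of the uniform marginals 1/n and 1/m

        for _ in range(max_steps):
            # alternately match the column marginals (1/m) and the row marginals (1/n)
            log_plan = log_plan - logsumexp(log_plan, axis=0, keepdims=True) + log_b
            log_plan = log_plan - logsumexp(log_plan, axis=1, keepdims=True) + log_a
            # rows are exact right after the row step, so only the columns need checking
            if np.allclose(logsumexp(log_plan, axis=0), log_b, atol=tol):
                break

        # logits of the transport plan: rows sum to 1/n, columns to 1/m after exponentiation
        return log_plan


    if __name__ == "__main__":
        rng = np.random.default_rng(0)
        plan = np.exp(log_sinkhorn_plan_sketch(rng.normal(size=(10, 2)), rng.normal(size=(20, 2)), 0.1))
        print(plan.sum(axis=0))  # each entry approximately 1/20
        print(plan.sum(axis=1))  # each entry approximately 1/10

Exponentiating the result and checking plan.sum(axis=0) against 1/m and plan.sum(axis=1) against 1/n mirrors the marginal checks added in test_sinkhorn_plan_correct_marginals and test_log_sinkhorn_plan_correct_marginals, and the stopping criterion corresponds to the logsumexp-based is_converged check introduced in patch 12.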