diff --git a/bayesflow/utils/optimal_transport/log_sinkhorn.py b/bayesflow/utils/optimal_transport/log_sinkhorn.py
index 9fa6dba26..2def2b0c7 100644
--- a/bayesflow/utils/optimal_transport/log_sinkhorn.py
+++ b/bayesflow/utils/optimal_transport/log_sinkhorn.py
@@ -8,10 +8,10 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
     """
     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
-    Significantly slower than the unstabilized version, so use only when you need numerical stability.
+    About 50% slower than the unstabilized version, so use only when you need numerical stability.
     """
     log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
-    assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
+    assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
     assignments = keras.ops.squeeze(assignments, axis=1)
 
     return assignments
@@ -20,19 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
 def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
     """
     Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
-    Significantly slower than the unstabilized version, so use only when you need numerical stability.
+    About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
     """
     cost = euclidean(x1, x2)
+    cost_scaled = -cost / regularization
 
-    log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
+    # initialize transport plan from a gaussian kernel
+    log_plan = cost_scaled - keras.ops.max(cost_scaled)
+    n, m = keras.ops.shape(log_plan)
+
+    log_a = -keras.ops.log(n)
+    log_b = -keras.ops.log(m)
 
     def contains_nans(plan):
         return keras.ops.any(keras.ops.isnan(plan))
 
     def is_converged(plan):
-        # for convergence, the plan should be doubly stochastic
-        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
-        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
+        # for convergence, the target marginals must match
+        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
+        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
         return conv0 & conv1
 
     def cond(_, plan):
@@ -41,8 +47,8 @@ def cond(_, plan):
 
     def body(steps, plan):
         # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
-        plan = keras.ops.log_softmax(plan, axis=0)
-        plan = keras.ops.log_softmax(plan, axis=1)
+        plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
+        plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a
 
         return steps + 1, plan
diff --git a/bayesflow/utils/optimal_transport/sinkhorn.py b/bayesflow/utils/optimal_transport/sinkhorn.py
index f7e0ba835..45c568294 100644
--- a/bayesflow/utils/optimal_transport/sinkhorn.py
+++ b/bayesflow/utils/optimal_transport/sinkhorn.py
@@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
     """
     Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.
 
-    Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic
+    Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a transport plan,
     containing assignment probabilities.
     The permutation is then sampled randomly according to the transport plan.
@@ -27,12 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
     :param seed: Random seed to use for sampling indices.
         Default: None, which means the seed will be auto-determined for non-compiled contexts.
 
-    :return: Tensor of shape (m,)
+    :return: Tensor of shape (n,)
         Assignment indices for x2.
     """
     plan = sinkhorn_plan(x1, x2, **kwargs)
-    assignments = keras.random.categorical(plan, num_samples=1, seed=seed)
+
+    # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
+    # such that x2[assignments] matches x1
+    assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
     assignments = keras.ops.squeeze(assignments, axis=1)
 
     return assignments
@@ -42,7 +45,7 @@ def sinkhorn_plan(
     x1: Tensor,
     x2: Tensor,
     regularization: float = 1.0,
-    max_steps: int = 10_000,
+    max_steps: int = None,
     rtol: float = 1e-5,
     atol: float = 1e-8,
 ) -> Tensor:
@@ -59,7 +62,7 @@ def sinkhorn_plan(
         Controls the standard deviation of the Gaussian kernel.
 
     :param max_steps: Maximum number of iterations, or None to run until convergence.
-        Default: 10_000
+        Default: None
 
     :param rtol: Relative tolerance for convergence.
         Default: 1e-5.
@@ -71,17 +74,20 @@ def sinkhorn_plan(
         The transport probabilities.
     """
     cost = euclidean(x1, x2)
+    cost_scaled = -cost / regularization
 
-    # initialize the transport plan from a gaussian kernel
-    plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
+    # initialize transport plan from a gaussian kernel
+    # (more numerically stable version of keras.ops.exp(-cost/regularization))
+    plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled))
+    n, m = keras.ops.shape(cost)
 
     def contains_nans(plan):
         return keras.ops.any(keras.ops.isnan(plan))
 
     def is_converged(plan):
-        # for convergence, the plan should be doubly stochastic
-        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0, rtol=rtol, atol=atol))
-        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0, rtol=rtol, atol=atol))
+        # for convergence, the target marginals must match
+        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol))
+        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol))
         return conv0 & conv1
 
     def cond(_, plan):
@@ -90,8 +96,8 @@ def cond(_, plan):
 
     def body(steps, plan):
         # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
-        plan = keras.ops.softmax(plan, axis=0)
-        plan = keras.ops.softmax(plan, axis=1)
+        plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m)
+        plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n)
 
         return steps + 1, plan
diff --git a/pyproject.toml b/pyproject.toml
index 902ffc1d2..f29938bba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,16 +36,16 @@ dependencies = [
 
 [project.optional-dependencies]
 all = [
     # dev
+    "ipython",
+    "ipykernel",
     "jupyter",
     "jupyterlab",
+    "line-profiler",
     "nbconvert",
-    "ipython",
-    "ipykernel",
     "pre-commit",
     "ruff",
     "tox",
     # docs
-    "myst-nb ~= 1.2",
     "numpydoc ~= 1.8",
     "pydata-sphinx-theme ~= 0.16",
@@ -63,6 +63,7 @@ all = [
 dev = [
     "jupyter",
     "jupyterlab",
+    "line-profiler",
     "pre-commit",
     "ruff",
     "tox",
diff --git a/tests/test_utils/test_optimal_transport.py b/tests/test_utils/test_optimal_transport.py
index dc7b7e93a..53a5fd7a6 100644
--- a/tests/test_utils/test_optimal_transport.py
+++ b/tests/test_utils/test_optimal_transport.py
@@ -27,30 +27,33 @@ def test_shapes(method):
     assert keras.ops.shape(oy) == keras.ops.shape(y)
 
 
-def test_transport_cost_improves():
+@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
+def test_transport_cost_improves(method):
     x = keras.random.normal((128, 2), seed=0)
     y = keras.random.normal((128, 2), seed=1)
 
     before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
 
-    x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000)
+    x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method)
 
     after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
 
     assert after_cost < before_cost
 
 
-@pytest.mark.skip(reason="too unreliable")
-def test_assignment_is_optimal():
-    x = keras.random.normal((16, 2), seed=0)
-    p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
-    optimal_assignments = keras.ops.argsort(p)
+@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
+def test_assignment_is_optimal(method):
+    y = keras.random.normal((16, 2), seed=0)
+    p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)
 
-    y = x[p]
+    x = keras.ops.take(y, p, axis=0)
 
-    x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, return_assignments=True)
+    _, _, assignments = optimal_transport(
+        x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
+    )
 
-    assert_allclose(assignments, optimal_assignments)
+    # transport is stochastic, so it is expected that a small fraction of assignments do not match
+    assert keras.ops.sum(assignments == p) > 14
 
 
 def test_assignment_aligns_with_pot():
@@ -58,6 +61,7 @@ def test_assignment_aligns_with_pot():
     try:
         from ot.bregman import sinkhorn_log
     except (ImportError, ModuleNotFoundError):
         pytest.skip("Need to install POT to run this test.")
+        return
 
     x = keras.random.normal((16, 2), seed=0)
     p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
@@ -68,10 +72,112 @@ def test_assignment_aligns_with_pot():
     M = x[:, None] - y[None, :]
     M = keras.ops.norm(M, axis=-1)
 
-    pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99)
-    pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0)
+    pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7)
+    pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
     pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)
 
     _, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True)
 
     assert_allclose(pot_assignments, assignments)
+
+
+def test_sinkhorn_plan_correct_marginals():
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+
+    x1 = keras.random.normal((10, 2), seed=0)
+    x2 = keras.random.normal((20, 2), seed=1)
+
+    assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6))
+    assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6))
+
+
+def test_sinkhorn_plan_aligns_with_pot():
+    try:
+        from ot.bregman import sinkhorn
+    except (ImportError, ModuleNotFoundError):
+        pytest.skip("Need to install POT to run this test.")
+
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+    from bayesflow.utils.optimal_transport.euclidean import euclidean
+
+    x1 = keras.random.normal((10, 3), seed=0)
+    x2 = keras.random.normal((20, 3), seed=1)
+
+    a = keras.ops.ones(10) / 10
+    b = keras.ops.ones(20) / 20
+    M = euclidean(x1, x2)
+
+    pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8)
+    our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7)
+
+    assert_allclose(pot_result, our_result)
+
+
+def test_sinkhorn_plan_matches_analytical_result():
+    from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
+
+    x1 = keras.ops.ones(16)
+    x2 = keras.ops.ones(64)
+
+    marginal_x1 = keras.ops.ones(16) / 16
+    marginal_x2 = keras.ops.ones(64) / 64
+
+    result = sinkhorn_plan(x1, x2, regularization=0.1)
+
+    # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
+    expected = keras.ops.outer(marginal_x1, marginal_x2)
+
+    assert_allclose(result, expected)
+
+
+def test_log_sinkhorn_plan_correct_marginals():
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+
+    x1 = keras.random.normal((10, 2), seed=0)
+    x2 = keras.random.normal((20, 2), seed=1)
+
+    assert keras.ops.all(
+        keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3)
+    )
+    assert keras.ops.all(
+        keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3)
+    )
+
+
+def test_log_sinkhorn_plan_aligns_with_pot():
+    try:
+        from ot.bregman import sinkhorn_log
+    except (ImportError, ModuleNotFoundError):
+        pytest.skip("Need to install POT to run this test.")
+
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+    from bayesflow.utils.optimal_transport.euclidean import euclidean
+
+    x1 = keras.random.normal((100, 3), seed=0)
+    x2 = keras.random.normal((200, 3), seed=1)
+
+    a = keras.ops.ones(100) / 100
+    b = keras.ops.ones(200) / 200
+    M = euclidean(x1, x2)
+
+    pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7))  # sinkhorn_log returns probabilities
+    our_result = log_sinkhorn_plan(x1, x2, regularization=0.1)
+
+    assert_allclose(pot_result, our_result)
+
+
+def test_log_sinkhorn_plan_matches_analytical_result():
+    from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
+
+    x1 = keras.ops.ones(16)
+    x2 = keras.ops.ones(64)
+
+    marginal_x1 = keras.ops.ones(16) / 16
+    marginal_x2 = keras.ops.ones(64) / 64
+
+    result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1))
+
+    # If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
+    expected = keras.ops.outer(marginal_x1, marginal_x2)
+
+    assert_allclose(result, expected)
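
Note on the normalization change in sinkhorn_plan: per-axis softmax pushes the plan
toward a doubly stochastic matrix (every row and column summing to 1), but a
transport plan between two uniform empirical distributions must have column sums
1/m and row sums 1/n. The patched body() rescales each axis to those targets
instead. Below is a standalone NumPy sketch of that iteration; it is illustrative
only, and sinkhorn_plan_sketch is a hypothetical name, not part of the library.

    import numpy as np

    def sinkhorn_plan_sketch(cost, regularization=1.0, steps=500):
        # Gibbs kernel; subtracting the max before exponentiating is the same
        # stabilization used in the patch (all exponents become non-positive)
        scaled = -cost / regularization
        plan = np.exp(scaled - scaled.max())

        n, m = cost.shape
        for _ in range(steps):
            # alternately rescale columns to sum to 1/m and rows to sum to 1/n,
            # the marginals of two uniform empirical distributions
            plan = plan / plan.sum(axis=0, keepdims=True) * (1.0 / m)
            plan = plan / plan.sum(axis=1, keepdims=True) * (1.0 / n)

        return plan

    rng = np.random.default_rng(0)
    x1 = rng.normal(size=(10, 3))
    x2 = rng.normal(size=(20, 3))
    cost = np.linalg.norm(x1[:, None] - x2[None, :], axis=-1)

    plan = sinkhorn_plan_sketch(cost)
    assert np.allclose(plan.sum(axis=1), 1.0 / 10)              # exact after the final row rescale
    assert np.allclose(plan.sum(axis=0), 1.0 / 20, atol=1e-6)   # tightens with more iterations

Because the final sweep normalizes rows, the row marginals hold exactly, while the
column marginals only converge over iterations; this is why is_converged checks
both axes.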
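
The sampling changes in log_sinkhorn and sinkhorn follow from the same API detail:
keras.random.categorical expects row-wise unnormalized log-probabilities (logits),
not probabilities. Exponentiating the log-plan, or passing a probability-space
plan directly, therefore skews the sampled indices; the patch passes log_plan
as-is and wraps the probability-space plan in keras.ops.log. A minimal sketch of
the sampling step, with a toy plan whose values are invented for illustration:

    import keras

    # toy 2x3 transport plan: row i is the assignment distribution of x1[i]
    # over the three candidates in x2
    plan = keras.ops.convert_to_tensor([[0.7, 0.2, 0.1],
                                        [0.1, 0.1, 0.8]])

    # categorical() reads its first argument as logits, so probabilities must
    # pass through log() first; a log-space plan can be passed directly
    assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=0)
    assignments = keras.ops.squeeze(assignments, axis=1)  # shape (2,): one x2 index per x1 element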
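
The *_matches_analytical_result tests rest on a closed form: when x1 and x2 are
constant, every pairwise cost is identical, the Gibbs kernel is a constant matrix,
and the entropy-regularized optimum is the independent coupling, i.e. the outer
product of the two uniform marginals. A quick self-contained NumPy check
(illustrative, same shapes as the tests) that a single marginal-matching sweep
already reaches that fixed point:

    import numpy as np

    n, m = 16, 64
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)

    # a constant cost matrix gives a constant Gibbs kernel
    plan = np.exp(-np.zeros((n, m)) / 0.1)

    # one sweep of the patched normalization
    plan = plan / plan.sum(axis=0, keepdims=True) * (1.0 / m)
    plan = plan / plan.sum(axis=1, keepdims=True) * (1.0 / n)

    # the fixed point is the independent coupling: the outer product of the marginals
    assert np.allclose(plan, np.outer(a, b))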