Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions bayesflow/utils/optimal_transport/log_sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
"""
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
Significantly slower than the unstabilized version, so use only when you need numerical stability.
About 50% slower than the unstabilized version, so use only when you need numerical stability.
"""
log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
assignments = keras.random.categorical(keras.ops.exp(log_plan), num_samples=1, seed=seed)
assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
assignments = keras.ops.squeeze(assignments, axis=1)

return assignments
Expand All @@ -20,19 +20,25 @@ def log_sinkhorn(x1, x2, seed: int = None, **kwargs):
def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8, max_steps=None):
"""
Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
Significantly slower than the unstabilized version, so use only when you need numerical stability.
About 50% slower than the unstabilized version, so use primarily when you need numerical stability.
"""
cost = euclidean(x1, x2)
cost_scaled = -cost / regularization

log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
# initialize transport plan from a gaussian kernel
log_plan = cost_scaled - keras.ops.max(cost_scaled)
n, m = keras.ops.shape(log_plan)

log_a = -keras.ops.log(n)
log_b = -keras.ops.log(m)

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))

def is_converged(plan):
# for convergence, the plan should be doubly stochastic
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
# for convergence, the target marginals must match
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), log_b, rtol=0.0, atol=rtol + atol))
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), log_a, rtol=0.0, atol=rtol + atol))
return conv0 & conv1

def cond(_, plan):
Expand All @@ -41,8 +47,8 @@ def cond(_, plan):

def body(steps, plan):
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
plan = keras.ops.log_softmax(plan, axis=0)
plan = keras.ops.log_softmax(plan, axis=1)
plan = plan - keras.ops.logsumexp(plan, axis=0, keepdims=True) + log_b
plan = plan - keras.ops.logsumexp(plan, axis=1, keepdims=True) + log_a

return steps + 1, plan

Expand Down
30 changes: 18 additions & 12 deletions bayesflow/utils/optimal_transport/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
"""
Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.

Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a doubly stochastic
Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
transport plan, containing assignment probabilities.
The permutation is then sampled randomly according to the transport plan.

Expand All @@ -27,12 +27,15 @@ def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Ten
:param seed: Random seed to use for sampling indices.
Default: None, which means the seed will be auto-determined for non-compiled contexts.

:return: Tensor of shape (m,)
:return: Tensor of shape (n,)
Assignment indices for x2.

"""
plan = sinkhorn_plan(x1, x2, **kwargs)
assignments = keras.random.categorical(plan, num_samples=1, seed=seed)

# we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
# such that x2[assignments] matches x1
assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
assignments = keras.ops.squeeze(assignments, axis=1)

return assignments
Expand All @@ -42,7 +45,7 @@ def sinkhorn_plan(
x1: Tensor,
x2: Tensor,
regularization: float = 1.0,
max_steps: int = 10_000,
max_steps: int = None,
rtol: float = 1e-5,
atol: float = 1e-8,
) -> Tensor:
Expand All @@ -59,7 +62,7 @@ def sinkhorn_plan(
Controls the standard deviation of the Gaussian kernel.

:param max_steps: Maximum number of iterations, or None to run until convergence.
Default: 10_000
Default: None

:param rtol: Relative tolerance for convergence.
Default: 1e-5.
Expand All @@ -71,17 +74,20 @@ def sinkhorn_plan(
The transport probabilities.
"""
cost = euclidean(x1, x2)
cost_scaled = -cost / regularization

# initialize the transport plan from a gaussian kernel
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
# initialize transport plan from a gaussian kernel
# (more numerically stable version of keras.ops.exp(-cost/regularization))
plan = keras.ops.exp(cost_scaled - keras.ops.max(cost_scaled))
n, m = keras.ops.shape(cost)

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))

def is_converged(plan):
# for convergence, the plan should be doubly stochastic
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0, rtol=rtol, atol=atol))
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0, rtol=rtol, atol=atol))
# for convergence, the target marginals must match
conv0 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=0), 1.0 / m, rtol=rtol, atol=atol))
conv1 = keras.ops.all(keras.ops.isclose(keras.ops.sum(plan, axis=1), 1.0 / n, rtol=rtol, atol=atol))
return conv0 & conv1

def cond(_, plan):
Expand All @@ -90,8 +96,8 @@ def cond(_, plan):

def body(steps, plan):
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
plan = keras.ops.softmax(plan, axis=0)
plan = keras.ops.softmax(plan, axis=1)
plan = plan / keras.ops.sum(plan, axis=0, keepdims=True) * (1.0 / m)
plan = plan / keras.ops.sum(plan, axis=1, keepdims=True) * (1.0 / n)

return steps + 1, plan

Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ dependencies = [
[project.optional-dependencies]
all = [
# dev
"ipython",
"ipykernel",
"jupyter",
"jupyterlab",
"line-profiler",
"nbconvert",
"ipython",
"ipykernel",
"pre-commit",
"ruff",
"tox",
# docs

"myst-nb ~= 1.2",
"numpydoc ~= 1.8",
"pydata-sphinx-theme ~= 0.16",
Expand All @@ -63,6 +63,7 @@ all = [
dev = [
"jupyter",
"jupyterlab",
"line-profiler",
"pre-commit",
"ruff",
"tox",
Expand Down
130 changes: 118 additions & 12 deletions tests/test_utils/test_optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,41 @@ def test_shapes(method):
assert keras.ops.shape(oy) == keras.ops.shape(y)


def test_transport_cost_improves():
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
def test_transport_cost_improves(method):
x = keras.random.normal((128, 2), seed=0)
y = keras.random.normal((128, 2), seed=1)

before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000)
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method)

after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

assert after_cost < before_cost


@pytest.mark.skip(reason="too unreliable")
def test_assignment_is_optimal():
x = keras.random.normal((16, 2), seed=0)
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
optimal_assignments = keras.ops.argsort(p)
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
def test_assignment_is_optimal(method):
y = keras.random.normal((16, 2), seed=0)
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(y)[0]), seed=0)

y = x[p]
x = keras.ops.take(y, p, axis=0)

x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, return_assignments=True)
_, _, assignments = optimal_transport(
x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
)

assert_allclose(assignments, optimal_assignments)
# transport is stochastic, so it is expected that a small fraction of assignments do not match
assert keras.ops.sum(assignments == p) > 14


def test_assignment_aligns_with_pot():
try:
from ot.bregman import sinkhorn_log
except (ImportError, ModuleNotFoundError):
pytest.skip("Need to install POT to run this test.")
return

x = keras.random.normal((16, 2), seed=0)
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
Expand All @@ -68,10 +72,112 @@ def test_assignment_aligns_with_pot():
M = x[:, None] - y[None, :]
M = keras.ops.norm(M, axis=-1)

pot_plan = sinkhorn_log(a, b, M, reg=1e-3, numItermax=10_000, stopThr=1e-99)
pot_assignments = keras.random.categorical(pot_plan, num_samples=1, seed=0)
pot_plan = sinkhorn_log(a, b, M, numItermax=10_000, reg=1e-3, stopThr=1e-7)
pot_assignments = keras.random.categorical(keras.ops.log(pot_plan), num_samples=1, seed=0)
pot_assignments = keras.ops.squeeze(pot_assignments, axis=-1)

_, _, assignments = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=10_000, return_assignments=True)

assert_allclose(pot_assignments, assignments)


def test_sinkhorn_plan_correct_marginals():
from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan

x1 = keras.random.normal((10, 2), seed=0)
x2 = keras.random.normal((20, 2), seed=1)

assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=0), 0.05, atol=1e-6))
assert keras.ops.all(keras.ops.isclose(keras.ops.sum(sinkhorn_plan(x1, x2), axis=1), 0.1, atol=1e-6))


def test_sinkhorn_plan_aligns_with_pot():
try:
from ot.bregman import sinkhorn
except (ImportError, ModuleNotFoundError):
pytest.skip("Need to install POT to run this test.")

from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan
from bayesflow.utils.optimal_transport.euclidean import euclidean

x1 = keras.random.normal((10, 3), seed=0)
x2 = keras.random.normal((20, 3), seed=1)

a = keras.ops.ones(10) / 10
b = keras.ops.ones(20) / 20
M = euclidean(x1, x2)

pot_result = sinkhorn(a, b, M, 0.1, stopThr=1e-8)
our_result = sinkhorn_plan(x1, x2, regularization=0.1, rtol=1e-7)

assert_allclose(pot_result, our_result)


def test_sinkhorn_plan_matches_analytical_result():
from bayesflow.utils.optimal_transport.sinkhorn import sinkhorn_plan

x1 = keras.ops.ones(16)
x2 = keras.ops.ones(64)

marginal_x1 = keras.ops.ones(16) / 16
marginal_x2 = keras.ops.ones(64) / 64

result = sinkhorn_plan(x1, x2, regularization=0.1)

# If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
expected = keras.ops.outer(marginal_x1, marginal_x2)

assert_allclose(result, expected)


def test_log_sinkhorn_plan_correct_marginals():
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan

x1 = keras.random.normal((10, 2), seed=0)
x2 = keras.random.normal((20, 2), seed=1)

assert keras.ops.all(
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=0), -keras.ops.log(20), atol=1e-3)
)
assert keras.ops.all(
keras.ops.isclose(keras.ops.logsumexp(log_sinkhorn_plan(x1, x2), axis=1), -keras.ops.log(10), atol=1e-3)
)


def test_log_sinkhorn_plan_aligns_with_pot():
try:
from ot.bregman import sinkhorn_log
except (ImportError, ModuleNotFoundError):
pytest.skip("Need to install POT to run this test.")

from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan
from bayesflow.utils.optimal_transport.euclidean import euclidean

x1 = keras.random.normal((100, 3), seed=0)
x2 = keras.random.normal((200, 3), seed=1)

a = keras.ops.ones(100) / 100
b = keras.ops.ones(200) / 200
M = euclidean(x1, x2)

pot_result = keras.ops.log(sinkhorn_log(a, b, M, 0.1, stopThr=1e-7)) # sinkhorn_log returns probabilities
our_result = log_sinkhorn_plan(x1, x2, regularization=0.1)

assert_allclose(pot_result, our_result)


def test_log_sinkhorn_plan_matches_analytical_result():
from bayesflow.utils.optimal_transport.log_sinkhorn import log_sinkhorn_plan

x1 = keras.ops.ones(16)
x2 = keras.ops.ones(64)

marginal_x1 = keras.ops.ones(16) / 16
marginal_x2 = keras.ops.ones(64) / 64

result = keras.ops.exp(log_sinkhorn_plan(x1, x2, regularization=0.1))

# If x1 and x2 are identical, the optimal plan is simply the outer product of the marginals
expected = keras.ops.outer(marginal_x1, marginal_x2)

assert_allclose(result, expected)