Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions bayesflow/utils/optimal_transport/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sinkhorn_plan(
x1: Tensor,
x2: Tensor,
regularization: float = 1.0,
max_steps: int = 10_000,
max_steps: int = None,
rtol: float = 1e-5,
atol: float = 1e-8,
) -> Tensor:
Expand All @@ -59,7 +59,7 @@ def sinkhorn_plan(
Controls the standard deviation of the Gaussian kernel.

:param max_steps: Maximum number of iterations, or None to run until convergence.
Default: 10_000
Default: None

:param rtol: Relative tolerance for convergence.
Default: 1e-5.
Expand All @@ -73,7 +73,11 @@ def sinkhorn_plan(
cost = euclidean(x1, x2)

# initialize the transport plan from a gaussian kernel
plan = keras.ops.exp(cost / -(regularization * keras.ops.mean(cost) + 1e-16))
logits = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
# numerical stability: we can subtract the maximum without changing the result
logits = logits - keras.ops.max(logits)
# exponentiate to get the initial transport plan
plan = keras.ops.exp(logits)

def contains_nans(plan):
return keras.ops.any(keras.ops.isnan(plan))
Expand All @@ -90,8 +94,8 @@ def cond(_, plan):

def body(steps, plan):
# Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
plan = keras.ops.softmax(plan, axis=0)
plan = keras.ops.softmax(plan, axis=1)
plan = plan / keras.ops.sum(plan, axis=0, keepdims=True)
plan = plan / keras.ops.sum(plan, axis=1, keepdims=True)

return steps + 1, plan

Expand Down
12 changes: 8 additions & 4 deletions tests/test_utils/test_optimal_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,32 @@ def test_shapes(method):
assert keras.ops.shape(oy) == keras.ops.shape(y)


def test_transport_cost_improves():
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
def test_transport_cost_improves(method):
x = keras.random.normal((128, 2), seed=0)
y = keras.random.normal((128, 2), seed=1)

before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000)
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=1000, method=method)

after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))

assert after_cost < before_cost


@pytest.mark.skip(reason="too unreliable")
def test_assignment_is_optimal():
@pytest.mark.parametrize("method", ["log_sinkhorn", "sinkhorn"])
def test_assignment_is_optimal(method):
x = keras.random.normal((16, 2), seed=0)
p = keras.random.shuffle(keras.ops.arange(keras.ops.shape(x)[0]), seed=0)
optimal_assignments = keras.ops.argsort(p)

y = x[p]

x, y, assignments = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=10_000, return_assignments=True)
x, y, assignments = optimal_transport(
x, y, regularization=0.1, seed=0, max_steps=10_000, method=method, return_assignments=True
)

assert_allclose(assignments, optimal_assignments)

Expand Down