import keras

from .. import logging
from ..tensor_utils import is_symbolic_tensor

from .euclidean import euclidean


def log_sinkhorn(x1, x2, seed: int | None = None, **kwargs):
    """
    Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn`.
    Significantly slower than the unstabilized version, so use only when you need numerical stability.
    """
    log_plan = log_sinkhorn_plan(x1, x2, **kwargs)
    # keras.random.categorical expects logits, so pass the log-plan directly
    # rather than exponentiating it into probabilities first
    assignments = keras.random.categorical(log_plan, num_samples=1, seed=seed)
    assignments = keras.ops.squeeze(assignments, axis=1)

    return assignments
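

# A minimal usage sketch (illustrative, shapes and values are assumptions, not
# part of the module): `log_sinkhorn` pairs each sample in x1 with a sampled
# partner in x2 by drawing from the log-stabilized transport plan.
#
#     x1 = keras.random.normal((16, 2), seed=0)
#     x2 = keras.random.normal((16, 2), seed=1)
#     assignments = log_sinkhorn(x1, x2, seed=42, regularization=0.1)
#     x2_matched = keras.ops.take(x2, assignments, axis=0)  # aligned with x1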


def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol: float = 1e-5, atol: float = 1e-8, max_steps: int | None = None):
    """
    Log-stabilized version of :py:func:`~bayesflow.utils.optimal_transport.sinkhorn.sinkhorn_plan`.
    Significantly slower than the unstabilized version, so use only when you need numerical stability.

    Returns the transport plan in log space, of shape (n, m) for inputs with n and m samples.
    """
    cost = euclidean(x1, x2)

    # initialize with the log of the Gibbs kernel exp(-cost / eps), where eps is
    # scaled by the mean cost; the 1e-16 term guards against an all-zero cost
    log_plan = -cost / (regularization * keras.ops.mean(cost) + 1e-16)

    if is_symbolic_tensor(log_plan):
        # during symbolic tracing, skip the iterative refinement
        return log_plan

    def contains_nans(plan):
        return keras.ops.any(keras.ops.isnan(plan))

    def is_converged(plan):
        # the plan is converged when it is doubly stochastic: in log space,
        # logsumexp along each axis should be close to 0 (row and column sums of 1)
        conv0 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=0), 0.0, rtol=rtol, atol=atol))
        conv1 = keras.ops.all(keras.ops.isclose(keras.ops.logsumexp(plan, axis=1), 0.0, rtol=rtol, atol=atol))
        return conv0 & conv1

    def cond(_, plan):
        # continue the while loop until the plan contains NaNs or is converged
        return ~(contains_nans(plan) | is_converged(plan))

    def body(steps, plan):
        # Sinkhorn-Knopp: repeatedly normalize the transport plan along each dimension
        plan = keras.ops.log_softmax(plan, axis=0)
        plan = keras.ops.log_softmax(plan, axis=1)

        return steps + 1, plan

    steps = 0
    steps, log_plan = keras.ops.while_loop(cond, body, (steps, log_plan), maximum_iterations=max_steps)

    def do_nothing():
        pass

    def log_steps():
        msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
        logging.info(msg, steps)

    def warn_convergence():
        # worst multiplicative deviation of the marginals from 1, in percent
        marginals = keras.ops.logsumexp(log_plan, axis=0)
        deviations = keras.ops.abs(marginals)
        badness = 100.0 * (keras.ops.exp(keras.ops.max(deviations)) - 1.0)

        msg = "Log-Sinkhorn-Knopp did not converge after {:d} steps (badness: {:.1f}%)."
        logging.warning(msg, steps, badness)

    def warn_nans():
        msg = "Log-Sinkhorn-Knopp produced NaNs."
        logging.warning(msg)

    keras.ops.cond(contains_nans(log_plan), warn_nans, do_nothing)
    keras.ops.cond(is_converged(log_plan), log_steps, warn_convergence)

    return log_plan
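

# A small sanity-check sketch (illustrative values, not part of the module):
# after convergence the plan is approximately doubly stochastic, so logsumexp
# along each axis should be near 0, i.e. every row and column sums to ~1:
#
#     x1 = keras.random.normal((8, 2), seed=0)
#     x2 = keras.random.normal((8, 2), seed=1)
#     log_plan = log_sinkhorn_plan(x1, x2, regularization=0.1, max_steps=1000)
#     row_sums = keras.ops.exp(keras.ops.logsumexp(log_plan, axis=1))
#     col_sums = keras.ops.exp(keras.ops.logsumexp(log_plan, axis=0))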