|
| 1 | +import keras |
| 2 | +import pytest |
| 3 | + |
| 4 | +from bayesflow.utils import optimal_transport |
| 5 | +from tests.utils import assert_allclose |
| 6 | + |
| 7 | + |
| 8 | +@pytest.mark.jax |
| 9 | +def test_jit_compile(): |
| 10 | + import jax |
| 11 | + |
| 12 | + x = keras.random.normal((100, 100), seed=0) |
| 13 | + y = keras.random.normal((100, 100), seed=1) |
| 14 | + |
| 15 | + ot = jax.jit(optimal_transport, static_argnames=["regularization", "seed"]) |
| 16 | + ot(x, y, regularization=0.1, seed=0) |
| 17 | + |
| 18 | + |
| 19 | +def test_transport_cost_improves(): |
| 20 | + x = keras.random.normal((100, 100), seed=0) |
| 21 | + y = keras.random.normal((100, 100), seed=1) |
| 22 | + |
| 23 | + before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) |
| 24 | + |
| 25 | + x, y = optimal_transport(x, y, regularization=0.1, seed=0) |
| 26 | + |
| 27 | + after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) |
| 28 | + |
| 29 | + assert after_cost < before_cost |
| 30 | + |
| 31 | + |
| 32 | +def test_assignment_is_optimal(): |
| 33 | + x = keras.ops.stack([keras.ops.linspace(-1, 1, 10), keras.ops.linspace(-1, 1, 10)]) |
| 34 | + y = keras.ops.copy(x) |
| 35 | + |
| 36 | + # we could shuffle x and y, but flipping is a more reliable permutation |
| 37 | + y = keras.ops.flip(y, axis=0) |
| 38 | + |
| 39 | + x, y = optimal_transport(x, y, regularization=1e-6, seed=0) |
| 40 | + |
| 41 | + cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1)) |
| 42 | + |
| 43 | + assert_allclose(cost, 0.0) |
0 commit comments