Skip to content

Commit efbf18b

Browse files
committed
add optimal transport tests
1 parent 8b95883 commit efbf18b

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import keras
2+
import pytest
3+
4+
from bayesflow.utils import optimal_transport
5+
from tests.utils import assert_allclose
6+
7+
8+
@pytest.mark.jax
9+
def test_jit_compile():
10+
import jax
11+
12+
x = keras.random.normal((100, 100), seed=0)
13+
y = keras.random.normal((100, 100), seed=1)
14+
15+
ot = jax.jit(optimal_transport, static_argnames=["regularization", "seed"])
16+
ot(x, y, regularization=0.1, seed=0)
17+
18+
19+
def test_transport_cost_improves():
20+
x = keras.random.normal((100, 100), seed=0)
21+
y = keras.random.normal((100, 100), seed=1)
22+
23+
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
24+
25+
x, y = optimal_transport(x, y, regularization=0.1, seed=0)
26+
27+
after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
28+
29+
assert after_cost < before_cost
30+
31+
32+
def test_assignment_is_optimal():
33+
x = keras.ops.stack([keras.ops.linspace(-1, 1, 10), keras.ops.linspace(-1, 1, 10)])
34+
y = keras.ops.copy(x)
35+
36+
# we could shuffle x and y, but flipping is a more reliable permutation
37+
y = keras.ops.flip(y, axis=0)
38+
39+
x, y = optimal_transport(x, y, regularization=1e-6, seed=0)
40+
41+
cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
42+
43+
assert_allclose(cost, 0.0)

0 commit comments

Comments
 (0)