Skip to content

Commit 5364a23

Browse files
committed
make optimal transport tests even more reliable
1 parent 036ec5d commit 5364a23

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def sinkhorn(
1616
cost: str | Tensor = "euclidean",
1717
seed: int = None,
1818
regularization: float = 1.0,
19-
max_steps: int = 10000,
19+
max_steps: int | None = 10_000,
2020
tolerance: float = 1e-6,
2121
numpy: bool = False,
2222
) -> (Tensor, Tensor):
@@ -88,7 +88,7 @@ def sinkhorn_indices(
8888
cost: str | Tensor = "euclidean",
8989
seed: int = None,
9090
regularization: float = 1.0,
91-
max_steps: int = 1000,
91+
max_steps: int | None = 10_000,
9292
tolerance: float = 1e-6,
9393
numpy: bool = False,
9494
) -> Tensor | np.ndarray:
@@ -111,7 +111,7 @@ def sinkhorn_indices(
111111
Default: 1.0
112112
113113
:param max_steps: Maximum number of iterations.
114-
Default: 1000
114+
Default: 10_000
115115
116116
:param tolerance: Absolute tolerance for convergence.
117117
Default: 1e-6
@@ -164,15 +164,13 @@ def sinkhorn_plan(
164164
165165
:param regularization: Regularization parameter.
166166
Controls the standard deviation of the Gaussian kernel.
167-
Default: 1.0
168167
169168
:param max_steps: Maximum number of iterations.
170-
Default: 1000
171169
172170
:param tolerance: Absolute tolerance for convergence.
173-
Default: 1e-6
174171
175172
:param numpy: Whether to use numpy or keras backend.
173+
Default: False
176174
177175
:return: Tensor of shape (n, m)
178176
The transport probabilities.

tests/test_utils/test_optimal_transport.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,30 @@
99
def test_jit_compile():
1010
import jax
1111

12-
x = keras.random.normal((100, 100), seed=0)
13-
y = keras.random.normal((100, 100), seed=1)
12+
x = keras.random.normal((128, 8), seed=0)
13+
y = keras.random.normal((128, 8), seed=1)
1414

1515
ot = jax.jit(optimal_transport, static_argnames=["regularization", "seed"])
16-
ot(x, y, regularization=0.1, seed=0)
16+
ot(x, y, regularization=1.0, seed=0, max_steps=10)
17+
18+
19+
def test_shapes():
20+
x = keras.random.normal((128, 8), seed=0)
21+
y = keras.random.normal((128, 8), seed=1)
22+
23+
ox, oy = optimal_transport(x, y, regularization=1.0, seed=0, max_steps=10)
24+
25+
assert keras.ops.shape(ox) == keras.ops.shape(x)
26+
assert keras.ops.shape(oy) == keras.ops.shape(y)
1727

1828

1929
def test_transport_cost_improves():
20-
x = keras.random.normal((100, 2), seed=0)
21-
y = keras.random.normal((100, 2), seed=1)
30+
x = keras.random.normal((32, 8), seed=0)
31+
y = keras.random.normal((32, 8), seed=1)
2232

2333
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
2434

25-
x, y = optimal_transport(x, y, regularization=1e-3, seed=0)
35+
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=None)
2636

2737
after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
2838

@@ -36,7 +46,7 @@ def test_assignment_is_optimal():
3646
# we could shuffle x and y, but flipping is a more reliable permutation
3747
y = keras.ops.flip(y, axis=0)
3848

39-
x, y = optimal_transport(x, y, regularization=1e-6, seed=0)
49+
x, y = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=1000)
4050

4151
cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
4252

0 commit comments

Comments
 (0)