Skip to content

Commit 91e84ba

Browse files
committed
remove random optimal transport method
1 parent 5364a23 commit 91e84ba

File tree

3 files changed

+25
-45
lines changed

3 files changed

+25
-45
lines changed

bayesflow/utils/optimal_transport/optimal_transport.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from bayesflow.types import Tensor
22

33
from .hungarian import hungarian
4-
from .random import random
54
from .sinkhorn import sinkhorn
65

76

@@ -39,7 +38,6 @@ def optimal_transport(
3938
"hungarian": hungarian,
4039
"sinkhorn": sinkhorn,
4140
"sinkhorn_knopp": sinkhorn,
42-
"random": random,
4341
}
4442

4543
method = method.lower()

bayesflow/utils/optimal_transport/random.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tests/test_utils/test_optimal_transport.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def test_shapes():
2727

2828

2929
def test_transport_cost_improves():
30-
x = keras.random.normal((32, 8), seed=0)
31-
y = keras.random.normal((32, 8), seed=1)
30+
x = keras.random.normal((1024, 2), seed=0)
31+
y = keras.random.normal((1024, 2), seed=1)
3232

3333
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
3434

@@ -40,14 +40,26 @@ def test_transport_cost_improves():
4040

4141

4242
def test_assignment_is_optimal():
43-
x = keras.ops.stack([keras.ops.linspace(-1, 1, 10), keras.ops.linspace(-1, 1, 10)])
44-
y = keras.ops.copy(x)
45-
46-
# we could shuffle x and y, but flipping is a more reliable permutation
47-
y = keras.ops.flip(y, axis=0)
48-
49-
x, y = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=1000)
50-
51-
cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
52-
53-
assert_allclose(cost, 0.0)
43+
x = keras.ops.convert_to_tensor(
44+
[
45+
[-1, 2],
46+
[-1, 1],
47+
[-1, 0],
48+
[-1, -1],
49+
[-1, -2],
50+
]
51+
)
52+
optimal_y = keras.ops.convert_to_tensor(
53+
[
54+
[1, 2],
55+
[1, 1],
56+
[1, 0],
57+
[1, -1],
58+
[1, -2],
59+
]
60+
)
61+
y = keras.random.shuffle(optimal_y, axis=0, seed=0)
62+
63+
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=None, scale_regularization=False)
64+
65+
assert_allclose(x, y)

0 commit comments

Comments
 (0)