Skip to content

Commit 3508fb5

Browse files
committed
Merge branch 'dev' of https://github.com/stefanradev93/BayesFlow into dev
2 parents 873ed8f + 5364a23 commit 3508fb5

File tree

5 files changed

+260
-8
lines changed

5 files changed

+260
-8
lines changed

bayesflow/utils/optimal_transport/sinkhorn.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def sinkhorn(
1616
cost: str | Tensor = "euclidean",
1717
seed: int = None,
1818
regularization: float = 1.0,
19-
max_steps: int = 10000,
19+
max_steps: int | None = 10_000,
2020
tolerance: float = 1e-6,
2121
numpy: bool = False,
2222
) -> (Tensor, Tensor):
@@ -88,7 +88,7 @@ def sinkhorn_indices(
8888
cost: str | Tensor = "euclidean",
8989
seed: int = None,
9090
regularization: float = 1.0,
91-
max_steps: int = 1000,
91+
max_steps: int | None = 10_000,
9292
tolerance: float = 1e-6,
9393
numpy: bool = False,
9494
) -> Tensor | np.ndarray:
@@ -111,7 +111,7 @@ def sinkhorn_indices(
111111
Default: 1.0
112112
113113
:param max_steps: Maximum number of iterations.
114-
Default: 1000
114+
Default: 10_000
115115
116116
:param tolerance: Absolute tolerance for convergence.
117117
Default: 1e-6
@@ -164,15 +164,13 @@ def sinkhorn_plan(
164164
165165
:param regularization: Regularization parameter.
166166
Controls the standard deviation of the Gaussian kernel.
167-
Default: 1.0
168167
169168
:param max_steps: Maximum number of iterations.
170-
Default: 1000
171169
172170
:param tolerance: Absolute tolerance for convergence.
173-
Default: 1e-6
174171
175172
:param numpy: Whether to use numpy or keras backend.
173+
Default: False
176174
177175
:return: Tensor of shape (n, m)
178176
The transport probabilities.

tests/test_utils/test_dispatch.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import keras
2+
import pytest
3+
4+
# Import the dispatch functions
5+
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
6+
7+
# --- Tests for find_network.py ---
8+
9+
10+
class DummyMLP:
11+
def __init__(self, *args, **kwargs):
12+
self.args = args
13+
self.kwargs = kwargs
14+
15+
16+
def test_find_network_with_string(monkeypatch):
17+
# Monkeypatch the MLP entry in bayesflow.networks
18+
monkeypatch.setattr("bayesflow.networks.MLP", DummyMLP)
19+
20+
net = find_network("mlp", 1, key="value")
21+
assert isinstance(net, DummyMLP)
22+
assert net.args == (1,)
23+
assert net.kwargs == {"key": "value"}
24+
25+
26+
def test_find_network_with_type():
27+
class CustomNet:
28+
def __init__(self, x):
29+
self.x = x
30+
31+
net = find_network(CustomNet, 42)
32+
assert isinstance(net, CustomNet)
33+
assert net.x == 42
34+
35+
36+
def test_find_network_with_keras_layer():
37+
layer = keras.layers.Dense(10)
38+
returned = find_network(layer)
39+
assert returned is layer
40+
41+
42+
def test_find_network_invalid_type():
43+
with pytest.raises(TypeError):
44+
find_network(123)
45+
46+
47+
# --- Tests for find_permutation.py ---
48+
49+
50+
class DummyRandomPermutation:
51+
def __init__(self, *args, **kwargs):
52+
self.args = args
53+
self.kwargs = kwargs
54+
55+
56+
class DummySwap:
57+
def __init__(self, *args, **kwargs):
58+
self.args = args
59+
self.kwargs = kwargs
60+
61+
62+
class DummyOrthogonalPermutation:
63+
def __init__(self, *args, **kwargs):
64+
self.args = args
65+
self.kwargs = kwargs
66+
67+
68+
def test_find_permutation_random(monkeypatch):
69+
type("dummy_mod", (), {"RandomPermutation": DummyRandomPermutation})
70+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.RandomPermutation", DummyRandomPermutation)
71+
perm = find_permutation("random", 99, flag=True)
72+
assert isinstance(perm, DummyRandomPermutation)
73+
assert perm.args == (99,)
74+
assert perm.kwargs == {"flag": True}
75+
76+
77+
@pytest.mark.parametrize(
78+
"name,dummy_cls",
79+
[("swap", DummySwap), ("learnable", DummyOrthogonalPermutation), ("orthogonal", DummyOrthogonalPermutation)],
80+
)
81+
def test_find_permutation_by_name(monkeypatch, name, dummy_cls):
82+
# Inject dummy classes for each permutation type
83+
if name == "swap":
84+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.Swap", dummy_cls)
85+
else:
86+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.OrthogonalPermutation", dummy_cls)
87+
perm = find_permutation(name, "a", b="c")
88+
assert isinstance(perm, dummy_cls)
89+
assert perm.args == ("a",)
90+
assert perm.kwargs == {"b": "c"}
91+
92+
93+
def test_find_permutation_with_keras_layer():
94+
layer = keras.layers.Activation("relu")
95+
perm = find_permutation(layer)
96+
assert perm is layer
97+
98+
99+
def test_find_permutation_with_none():
100+
res = find_permutation(None)
101+
assert res is None
102+
103+
104+
def test_find_permutation_invalid_type():
105+
with pytest.raises(TypeError):
106+
find_permutation(3.14)
107+
108+
109+
# --- Tests for find_pooling.py ---
110+
111+
112+
def dummy_pooling_constructor(*args, **kwargs):
113+
return {"args": args, "kwargs": kwargs}
114+
115+
116+
def test_find_pooling_mean():
117+
pooling = find_pooling("mean")
118+
# Check that a keras Lambda layer is returned
119+
assert isinstance(pooling, keras.layers.Lambda)
120+
# Test that the lambda function produces a mean when applied to a sample tensor.
121+
import numpy as np
122+
123+
sample = np.array([[1, 2], [3, 4]])
124+
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
125+
result = pooling.call(sample)
126+
np.testing.assert_allclose(result, sample.mean(axis=-2))
127+
128+
129+
@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
130+
def test_find_pooling_max_min(name, func):
131+
pooling = find_pooling(name)
132+
assert isinstance(pooling, keras.layers.Lambda)
133+
import numpy as np
134+
135+
sample = np.array([[1, 2], [3, 4]])
136+
result = pooling.call(sample)
137+
np.testing.assert_allclose(result, func(sample, axis=-2))
138+
139+
140+
def test_find_pooling_learnable(monkeypatch):
141+
# Monkey patch the PoolingByMultiHeadAttention in its module
142+
class DummyPoolingAttention:
143+
def __init__(self, *args, **kwargs):
144+
self.args = args
145+
self.kwargs = kwargs
146+
147+
monkeypatch.setattr("bayesflow.networks.transformers.pma.PoolingByMultiHeadAttention", DummyPoolingAttention)
148+
pooling = find_pooling("learnable", 7, option="test")
149+
assert isinstance(pooling, DummyPoolingAttention)
150+
assert pooling.args == (7,)
151+
assert pooling.kwargs == {"option": "test"}
152+
153+
154+
def test_find_pooling_with_constructor():
155+
# Passing a type should result in an instance.
156+
class DummyPooling:
157+
def __init__(self, data):
158+
self.data = data
159+
160+
pooling = find_pooling(DummyPooling, "dummy")
161+
assert isinstance(pooling, DummyPooling)
162+
assert pooling.data == "dummy"
163+
164+
165+
def test_find_pooling_with_keras_layer():
166+
layer = keras.layers.ReLU()
167+
pooling = find_pooling(layer)
168+
assert pooling is layer
169+
170+
171+
def test_find_pooling_invalid_type():
172+
with pytest.raises(TypeError):
173+
find_pooling(123)
174+
175+
176+
# --- Tests for find_recurrent_net.py ---
177+
178+
179+
def test_find_recurrent_net_lstm():
180+
constructor = find_recurrent_net("lstm")
181+
assert constructor is keras.layers.LSTM
182+
183+
184+
def test_find_recurrent_net_gru():
185+
constructor = find_recurrent_net("gru")
186+
assert constructor is keras.layers.GRU
187+
188+
189+
def test_find_recurrent_net_with_keras_layer():
190+
layer = keras.layers.SimpleRNN(5)
191+
net = find_recurrent_net(layer)
192+
assert net is layer
193+
194+
195+
def test_find_recurrent_net_invalid_name():
196+
with pytest.raises(ValueError):
197+
find_recurrent_net("invalid_net")
198+
199+
200+
def test_find_recurrent_net_invalid_type():
201+
with pytest.raises(TypeError):
202+
find_recurrent_net(3.1415)

tests/utils/ecdf.py renamed to tests/test_utils/test_ecdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_data():
1515

1616
def test_fractional_ranks(test_data):
1717
post_samples, prior_samples, _ = test_data
18-
# Compute expected result manually
18+
# Compute the expected result manually
1919
expected = np.mean(post_samples < prior_samples[:, np.newaxis, :], axis=1)
2020
result = fractional_ranks(post_samples, prior_samples)
2121
np.testing.assert_almost_equal(result, expected, decimal=6)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import keras
2+
import pytest
3+
4+
from bayesflow.utils import optimal_transport
5+
from tests.utils import assert_allclose
6+
7+
8+
@pytest.mark.jax
9+
def test_jit_compile():
10+
import jax
11+
12+
x = keras.random.normal((128, 8), seed=0)
13+
y = keras.random.normal((128, 8), seed=1)
14+
15+
ot = jax.jit(optimal_transport, static_argnames=["regularization", "seed"])
16+
ot(x, y, regularization=1.0, seed=0, max_steps=10)
17+
18+
19+
def test_shapes():
20+
x = keras.random.normal((128, 8), seed=0)
21+
y = keras.random.normal((128, 8), seed=1)
22+
23+
ox, oy = optimal_transport(x, y, regularization=1.0, seed=0, max_steps=10)
24+
25+
assert keras.ops.shape(ox) == keras.ops.shape(x)
26+
assert keras.ops.shape(oy) == keras.ops.shape(y)
27+
28+
29+
def test_transport_cost_improves():
30+
x = keras.random.normal((32, 8), seed=0)
31+
y = keras.random.normal((32, 8), seed=1)
32+
33+
before_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
34+
35+
x, y = optimal_transport(x, y, regularization=0.1, seed=0, max_steps=None)
36+
37+
after_cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
38+
39+
assert after_cost < before_cost
40+
41+
42+
def test_assignment_is_optimal():
43+
x = keras.ops.stack([keras.ops.linspace(-1, 1, 10), keras.ops.linspace(-1, 1, 10)])
44+
y = keras.ops.copy(x)
45+
46+
# we could shuffle x and y, but flipping is a more reliable permutation
47+
y = keras.ops.flip(y, axis=0)
48+
49+
x, y = optimal_transport(x, y, regularization=1e-3, seed=0, max_steps=1000)
50+
51+
cost = keras.ops.sum(keras.ops.norm(x - y, axis=-1))
52+
53+
assert_allclose(cost, 0.0)

tests/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .assertions import *
22
from .callbacks import *
33
from .check_combinations import *
4-
from .ecdf import *
54
from .jupyter import *
65
from .ops import *

0 commit comments

Comments
 (0)