Skip to content

Commit e95f181

Browse files
committed
add dispatch tests
1 parent f804865 commit e95f181

File tree

1 file changed

+202
-0
lines changed

1 file changed

+202
-0
lines changed

tests/test_utils/test_dispatch.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import keras
2+
import pytest
3+
4+
# Import the dispatch functions
5+
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
6+
7+
# --- Tests for find_network.py ---
8+
9+
10+
class DummyMLP:
11+
def __init__(self, *args, **kwargs):
12+
self.args = args
13+
self.kwargs = kwargs
14+
15+
16+
def test_find_network_with_string(monkeypatch):
17+
# Monkeypatch the MLP entry in bayesflow.networks
18+
monkeypatch.setattr("bayesflow.networks.MLP", DummyMLP)
19+
20+
net = find_network("mlp", 1, key="value")
21+
assert isinstance(net, DummyMLP)
22+
assert net.args == (1,)
23+
assert net.kwargs == {"key": "value"}
24+
25+
26+
def test_find_network_with_type():
27+
class CustomNet:
28+
def __init__(self, x):
29+
self.x = x
30+
31+
net = find_network(CustomNet, 42)
32+
assert isinstance(net, CustomNet)
33+
assert net.x == 42
34+
35+
36+
def test_find_network_with_keras_layer():
37+
layer = keras.layers.Dense(10)
38+
returned = find_network(layer)
39+
assert returned is layer
40+
41+
42+
def test_find_network_invalid_type():
43+
with pytest.raises(TypeError):
44+
find_network(123)
45+
46+
47+
# --- Tests for find_permutation.py ---
48+
49+
50+
class DummyRandomPermutation:
51+
def __init__(self, *args, **kwargs):
52+
self.args = args
53+
self.kwargs = kwargs
54+
55+
56+
class DummySwap:
57+
def __init__(self, *args, **kwargs):
58+
self.args = args
59+
self.kwargs = kwargs
60+
61+
62+
class DummyOrthogonalPermutation:
63+
def __init__(self, *args, **kwargs):
64+
self.args = args
65+
self.kwargs = kwargs
66+
67+
68+
def test_find_permutation_random(monkeypatch):
69+
type("dummy_mod", (), {"RandomPermutation": DummyRandomPermutation})
70+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.RandomPermutation", DummyRandomPermutation)
71+
perm = find_permutation("random", 99, flag=True)
72+
assert isinstance(perm, DummyRandomPermutation)
73+
assert perm.args == (99,)
74+
assert perm.kwargs == {"flag": True}
75+
76+
77+
@pytest.mark.parametrize(
78+
"name,dummy_cls",
79+
[("swap", DummySwap), ("learnable", DummyOrthogonalPermutation), ("orthogonal", DummyOrthogonalPermutation)],
80+
)
81+
def test_find_permutation_by_name(monkeypatch, name, dummy_cls):
82+
# Inject dummy classes for each permutation type
83+
if name == "swap":
84+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.Swap", dummy_cls)
85+
else:
86+
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.OrthogonalPermutation", dummy_cls)
87+
perm = find_permutation(name, "a", b="c")
88+
assert isinstance(perm, dummy_cls)
89+
assert perm.args == ("a",)
90+
assert perm.kwargs == {"b": "c"}
91+
92+
93+
def test_find_permutation_with_keras_layer():
94+
layer = keras.layers.Activation("relu")
95+
perm = find_permutation(layer)
96+
assert perm is layer
97+
98+
99+
def test_find_permutation_with_none():
100+
res = find_permutation(None)
101+
assert res is None
102+
103+
104+
def test_find_permutation_invalid_type():
105+
with pytest.raises(TypeError):
106+
find_permutation(3.14)
107+
108+
109+
# --- Tests for find_pooling.py ---
110+
111+
112+
def dummy_pooling_constructor(*args, **kwargs):
113+
return {"args": args, "kwargs": kwargs}
114+
115+
116+
def test_find_pooling_mean():
117+
pooling = find_pooling("mean")
118+
# Check that a keras Lambda layer is returned
119+
assert isinstance(pooling, keras.layers.Lambda)
120+
# Test that the lambda function produces a mean when applied to a sample tensor.
121+
import numpy as np
122+
123+
sample = np.array([[1, 2], [3, 4]])
124+
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
125+
result = pooling.call(sample)
126+
np.testing.assert_allclose(result, sample.mean(axis=-2))
127+
128+
129+
@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
130+
def test_find_pooling_max_min(name, func):
131+
pooling = find_pooling(name)
132+
assert isinstance(pooling, keras.layers.Lambda)
133+
import numpy as np
134+
135+
sample = np.array([[1, 2], [3, 4]])
136+
result = pooling.call(sample)
137+
np.testing.assert_allclose(result, func(sample, axis=-2))
138+
139+
140+
def test_find_pooling_learnable(monkeypatch):
141+
# Monkey patch the PoolingByMultiHeadAttention in its module
142+
class DummyPoolingAttention:
143+
def __init__(self, *args, **kwargs):
144+
self.args = args
145+
self.kwargs = kwargs
146+
147+
monkeypatch.setattr("bayesflow.networks.transformers.pma.PoolingByMultiHeadAttention", DummyPoolingAttention)
148+
pooling = find_pooling("learnable", 7, option="test")
149+
assert isinstance(pooling, DummyPoolingAttention)
150+
assert pooling.args == (7,)
151+
assert pooling.kwargs == {"option": "test"}
152+
153+
154+
def test_find_pooling_with_constructor():
155+
# Passing a type should result in an instance.
156+
class DummyPooling:
157+
def __init__(self, data):
158+
self.data = data
159+
160+
pooling = find_pooling(DummyPooling, "dummy")
161+
assert isinstance(pooling, DummyPooling)
162+
assert pooling.data == "dummy"
163+
164+
165+
def test_find_pooling_with_keras_layer():
166+
layer = keras.layers.ReLU()
167+
pooling = find_pooling(layer)
168+
assert pooling is layer
169+
170+
171+
def test_find_pooling_invalid_type():
172+
with pytest.raises(TypeError):
173+
find_pooling(123)
174+
175+
176+
# --- Tests for find_recurrent_net.py ---
177+
178+
179+
def test_find_recurrent_net_lstm():
180+
constructor = find_recurrent_net("lstm")
181+
assert constructor is keras.layers.LSTM
182+
183+
184+
def test_find_recurrent_net_gru():
185+
constructor = find_recurrent_net("gru")
186+
assert constructor is keras.layers.GRU
187+
188+
189+
def test_find_recurrent_net_with_keras_layer():
190+
layer = keras.layers.SimpleRNN(5)
191+
net = find_recurrent_net(layer)
192+
assert net is layer
193+
194+
195+
def test_find_recurrent_net_invalid_name():
196+
with pytest.raises(ValueError):
197+
find_recurrent_net("invalid_net")
198+
199+
200+
def test_find_recurrent_net_invalid_type():
201+
with pytest.raises(TypeError):
202+
find_recurrent_net(3.1415)

0 commit comments

Comments
 (0)