Skip to content

Commit ec0ee2f

Browse files
committed
update dispatch tests for more coverage
1 parent 16491be commit ec0ee2f

File tree

1 file changed

+112
-143
lines changed

1 file changed

+112
-143
lines changed

tests/test_utils/test_dispatch.py

Lines changed: 112 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,201 +1,170 @@
11
import keras
22
import pytest
33

4-
# Import the dispatch functions
5-
from bayesflow.utils import find_network, find_permutation, find_pooling, find_recurrent_net
6-
from tests.utils import assert_allclose
4+
from bayesflow.utils import find_inference_network, find_distribution, find_summary_network
75

8-
# --- Tests for find_network.py ---
96

7+
# --- Tests for find_inference_network.py ---
108

11-
class DummyMLP:
12-
def __init__(self, *args, **kwargs):
13-
self.args = args
14-
self.kwargs = kwargs
159

10+
class DummyInferenceNetwork:
11+
def __init__(self, *a, **kw):
12+
self.args = a
13+
self.kwargs = kw
1614

17-
def test_find_network_with_string(monkeypatch):
18-
# Monkeypatch the MLP entry in bayesflow.networks
19-
monkeypatch.setattr("bayesflow.networks.MLP", DummyMLP)
20-
21-
net = find_network("mlp", 1, key="value")
22-
assert isinstance(net, DummyMLP)
23-
assert net.args == (1,)
24-
assert net.kwargs == {"key": "value"}
2515

16+
@pytest.mark.parametrize(
17+
"name,expected_class_path",
18+
[
19+
("coupling_flow", "bayesflow.networks.CouplingFlow"),
20+
("flow_matching", "bayesflow.networks.FlowMatching"),
21+
("consistency_model", "bayesflow.networks.ConsistencyModel"),
22+
],
23+
)
24+
def test_find_inference_network_by_name(monkeypatch, name, expected_class_path):
25+
# patch the expected class in bayesflow.networks
26+
components = expected_class_path.split(".")
27+
module_path = ".".join(components[:-1])
28+
class_name = components[-1]
2629

27-
def test_find_network_with_type():
28-
class CustomNet:
29-
def __init__(self, x):
30-
self.x = x
30+
dummy_cls = DummyInferenceNetwork
31+
monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls)
3132

32-
net = find_network(CustomNet, 42)
33-
assert isinstance(net, CustomNet)
34-
assert net.x == 42
33+
net = find_inference_network(name, 1, key="val")
34+
assert isinstance(net, DummyInferenceNetwork)
35+
assert net.args == (1,)
36+
assert net.kwargs == {"key": "val"}
3537

3638

37-
def test_find_network_with_keras_layer():
39+
def test_find_inference_network_by_keras_layer():
3840
layer = keras.layers.Dense(10)
39-
returned = find_network(layer)
40-
assert returned is layer
41-
42-
43-
def test_find_network_invalid_type():
44-
with pytest.raises(TypeError):
45-
find_network(123)
41+
result = find_inference_network(layer)
42+
assert result is layer
4643

4744

48-
# --- Tests for find_permutation.py ---
45+
def test_find_inference_network_by_keras_model():
46+
model = keras.models.Sequential()
47+
result = find_inference_network(model)
48+
assert result is model
4949

5050

51-
class DummyRandomPermutation:
52-
def __init__(self, *args, **kwargs):
53-
self.args = args
54-
self.kwargs = kwargs
51+
def test_find_inference_network_unknown_name():
52+
with pytest.raises(ValueError):
53+
find_inference_network("unknown_network_name")
5554

5655

57-
class DummySwap:
58-
def __init__(self, *args, **kwargs):
59-
self.args = args
60-
self.kwargs = kwargs
56+
def test_find_inference_network_invalid_type():
57+
with pytest.raises(TypeError):
58+
find_inference_network(12345)
6159

6260

63-
class DummyOrthogonalPermutation:
64-
def __init__(self, *args, **kwargs):
65-
self.args = args
66-
self.kwargs = kwargs
61+
# --- Tests for find_distribution.py ---
6762

6863

69-
def test_find_permutation_random(monkeypatch):
70-
type("dummy_mod", (), {"RandomPermutation": DummyRandomPermutation})
71-
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.RandomPermutation", DummyRandomPermutation)
72-
perm = find_permutation("random", 99, flag=True)
73-
assert isinstance(perm, DummyRandomPermutation)
74-
assert perm.args == (99,)
75-
assert perm.kwargs == {"flag": True}
64+
class DummyDistribution:
65+
def __init__(self, *a, **kw):
66+
self.args = a
67+
self.kwargs = kw
7668

7769

7870
@pytest.mark.parametrize(
79-
"name,dummy_cls",
80-
[("swap", DummySwap), ("learnable", DummyOrthogonalPermutation), ("orthogonal", DummyOrthogonalPermutation)],
71+
"name, expected_class_path",
72+
[
73+
("normal", "bayesflow.distributions.DiagonalNormal"),
74+
("student", "bayesflow.distributions.DiagonalStudentT"),
75+
("student-t", "bayesflow.distributions.DiagonalStudentT"),
76+
("student_t", "bayesflow.distributions.DiagonalStudentT"),
77+
],
8178
)
82-
def test_find_permutation_by_name(monkeypatch, name, dummy_cls):
83-
# Inject dummy classes for each permutation type
84-
if name == "swap":
85-
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.Swap", dummy_cls)
86-
else:
87-
monkeypatch.setattr("bayesflow.networks.coupling_flow.permutations.OrthogonalPermutation", dummy_cls)
88-
perm = find_permutation(name, "a", b="c")
89-
assert isinstance(perm, dummy_cls)
90-
assert perm.args == ("a",)
91-
assert perm.kwargs == {"b": "c"}
92-
79+
def test_find_distribution_by_name(monkeypatch, name, expected_class_path):
80+
components = expected_class_path.split(".")
81+
module_path = ".".join(components[:-1])
82+
class_name = components[-1]
9383

94-
def test_find_permutation_with_keras_layer():
95-
layer = keras.layers.Activation("relu")
96-
perm = find_permutation(layer)
97-
assert perm is layer
84+
dummy_cls = DummyDistribution
85+
monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls)
9886

87+
dist = find_distribution(name, 10, a=5)
88+
assert isinstance(dist, DummyDistribution)
89+
assert dist.args == (10,)
90+
assert dist.kwargs == {"a": 5}
9991

100-
def test_find_permutation_with_none():
101-
res = find_permutation(None)
102-
assert res is None
103-
104-
105-
def test_find_permutation_invalid_type():
106-
with pytest.raises(TypeError):
107-
find_permutation(3.14)
10892

93+
def test_find_distribution_none_returns_none():
94+
assert find_distribution(None) is None
10995

110-
# --- Tests for find_pooling.py ---
11196

97+
def test_find_distribution_with_keras_layer():
98+
layer = keras.layers.Dense(3)
99+
result = find_distribution(layer)
100+
assert result is layer
112101

113-
def dummy_pooling_constructor(*args, **kwargs):
114-
return {"args": args, "kwargs": kwargs}
115102

103+
def test_find_distribution_mixture_raises():
104+
with pytest.raises(ValueError):
105+
find_distribution("mixture")
116106

117-
def test_find_pooling_mean():
118-
pooling = find_pooling("mean")
119-
# Check that a keras Lambda layer is returned
120-
assert isinstance(pooling, keras.layers.Lambda)
121-
# Test that the lambda function produces a mean when applied to a sample tensor.
122-
123-
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
124-
# Keras Lambda layers expect tensors via call(), here we simply call the layer's function.
125-
result = pooling.call(sample)
126-
assert_allclose(result, keras.ops.mean(sample, axis=-2))
127-
128-
129-
@pytest.mark.parametrize("name,func", [("max", keras.ops.max), ("min", keras.ops.min)])
130-
def test_find_pooling_max_min(name, func):
131-
pooling = find_pooling(name)
132-
assert isinstance(pooling, keras.layers.Lambda)
133-
134-
sample = keras.ops.convert_to_tensor([[1, 2], [3, 4]])
135-
result = pooling.call(sample)
136-
assert_allclose(result, func(sample, axis=-2))
137-
138-
139-
def test_find_pooling_learnable(monkeypatch):
140-
# Monkey patch the PoolingByMultiHeadAttention in its module
141-
class DummyPoolingAttention:
142-
def __init__(self, *args, **kwargs):
143-
self.args = args
144-
self.kwargs = kwargs
145-
146-
monkeypatch.setattr("bayesflow.networks.transformers.pma.PoolingByMultiHeadAttention", DummyPoolingAttention)
147-
pooling = find_pooling("learnable", 7, option="test")
148-
assert isinstance(pooling, DummyPoolingAttention)
149-
assert pooling.args == (7,)
150-
assert pooling.kwargs == {"option": "test"}
151107

108+
def test_find_distribution_invalid_name():
109+
with pytest.raises(ValueError):
110+
find_distribution("invalid_name")
152111

153-
def test_find_pooling_with_constructor():
154-
# Passing a type should result in an instance.
155-
class DummyPooling:
156-
def __init__(self, data):
157-
self.data = data
158112

159-
pooling = find_pooling(DummyPooling, "dummy")
160-
assert isinstance(pooling, DummyPooling)
161-
assert pooling.data == "dummy"
113+
def test_find_distribution_invalid_type():
114+
with pytest.raises(TypeError):
115+
find_distribution(3.14)
162116

163117

164-
def test_find_pooling_with_keras_layer():
165-
layer = keras.layers.ReLU()
166-
pooling = find_pooling(layer)
167-
assert pooling is layer
118+
# --- Tests for find_summary_network.py ---
168119

169120

170-
def test_find_pooling_invalid_type():
171-
with pytest.raises(TypeError):
172-
find_pooling(123)
121+
class DummySummaryNetwork:
122+
def __init__(self, *a, **kw):
123+
self.args = a
124+
self.kwargs = kw
173125

174126

175-
# --- Tests for find_recurrent_net.py ---
127+
@pytest.mark.parametrize(
128+
"name,expected_class_path",
129+
[
130+
("deep_set", "bayesflow.networks.DeepSet"),
131+
("set_transformer", "bayesflow.networks.SetTransformer"),
132+
("fusion_transformer", "bayesflow.networks.FusionTransformer"),
133+
("time_series_transformer", "bayesflow.networks.TimeSeriesTransformer"),
134+
("time_series_network", "bayesflow.networks.TimeSeriesNetwork"),
135+
],
136+
)
137+
def test_find_summary_network_by_name(monkeypatch, name, expected_class_path):
138+
components = expected_class_path.split(".")
139+
module_path = ".".join(components[:-1])
140+
class_name = components[-1]
176141

142+
dummy_cls = DummySummaryNetwork
143+
monkeypatch.setattr(f"{module_path}.{class_name}", dummy_cls)
177144

178-
def test_find_recurrent_net_lstm():
179-
constructor = find_recurrent_net("lstm")
180-
assert constructor is keras.layers.LSTM
145+
net = find_summary_network(name, 22, flag=True)
146+
assert isinstance(net, DummySummaryNetwork)
147+
assert net.args == (22,)
148+
assert net.kwargs == {"flag": True}
181149

182150

183-
def test_find_recurrent_net_gru():
184-
constructor = find_recurrent_net("gru")
185-
assert constructor is keras.layers.GRU
151+
def test_find_summary_network_by_keras_layer():
152+
layer = keras.layers.Dense(1)
153+
out = find_summary_network(layer)
154+
assert out is layer
186155

187156

188-
def test_find_recurrent_net_with_keras_layer():
189-
layer = keras.layers.SimpleRNN(5)
190-
net = find_recurrent_net(layer)
191-
assert net is layer
157+
def test_find_summary_network_by_keras_model():
158+
model = keras.models.Sequential()
159+
out = find_summary_network(model)
160+
assert out is model
192161

193162

194-
def test_find_recurrent_net_invalid_name():
163+
def test_find_summary_network_unknown_name():
195164
with pytest.raises(ValueError):
196-
find_recurrent_net("invalid_net")
165+
find_summary_network("unknown_summary_net")
197166

198167

199-
def test_find_recurrent_net_invalid_type():
168+
def test_find_summary_network_invalid_type():
200169
with pytest.raises(TypeError):
201-
find_recurrent_net(3.1415)
170+
find_summary_network(0.1234)

0 commit comments

Comments
 (0)