Skip to content

Commit 3e4813a

Browse files
committed
address code review from Lars
1 parent 3c93679 commit 3e4813a

File tree

3 files changed

+54
-22
lines changed

3 files changed

+54
-22
lines changed

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def __init__(
2323
p: Sequence[float] = None,
2424
logits: Sequence[float] = None,
2525
use_mixed_batches: bool = True,
26-
key_conflicts: str | float = "drop",
26+
key_conflicts: str = "drop",
27+
fill_value: float = np.nan,
2728
shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
2829
):
2930
"""
@@ -43,11 +44,13 @@ def __init__(
4344
Whether to draw samples in a batch from different models.
4445
- If True (default), each sample in a batch may come from a different model.
4546
- If False, the entire batch is drawn from a single model, selected according to model probabilities.
46-
key_conflicts : {"drop"} | float, optional
47+
key_conflicts : str, optional
4748
Policy for handling keys that are missing in the output of some models, when using mixed batches.
4849
- "drop" (default): Drop conflicting keys from the batch output.
49-
- float: Fill missing keys with the specified value.
50-
- If neither "drop" nor a float is given, an error is raised when key conflicts are detected.
50+
- "fill": Fill missing keys with the specified value.
51+
- "error": An error is raised when key conflicts are detected.
52+
fill_value : float, optional
53+
If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument.
5154
shared_simulator : Simulator or Callable, optional
5255
A shared simulator whose outputs are passed to all model simulators. If a function is
5356
provided, it is wrapped in a `LambdaSimulator` with batching enabled.
@@ -77,6 +80,7 @@ def __init__(
7780
self.logits = logits
7881
self.use_mixed_batches = use_mixed_batches
7982
self.key_conflicts = key_conflicts
83+
self.fill_value = fill_value
8084
self._keys = None
8185

8286
@allow_batch_size
@@ -139,30 +143,22 @@ def _handle_key_conflicts(self, sims, batch_sizes):
139143
if all_keys == common_keys:
140144
return sims
141145

142-
# keep only common keys
143146
if self.key_conflicts == "drop":
144147
sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims]
145148
return sims
146-
147-
# try to fill with key_conflicts to shape of the values from other model
148-
if isinstance(self.key_conflicts, (float, int)):
149+
elif self.key_conflicts == "fill":
149150
combined_sims = {}
150151
for sim in sims:
151152
combined_sims = combined_sims | sim
152-
153153
for i, sim in enumerate(sims):
154154
for missing_key in missing_keys[i]:
155155
shape = combined_sims[missing_key].shape
156156
shape = list(shape)
157157
shape[0] = batch_sizes[i]
158-
159-
sim[missing_key] = np.full(shape=shape, fill_value=self.key_conflicts)
160-
158+
sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value)
161159
return sims
162-
163-
raise ValueError(
164-
"Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided."
165-
)
160+
elif self.key_conflicts == "error":
161+
raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.")
166162

167163
def _determine_key_conflicts(self, sims):
168164
# determine only once
@@ -184,11 +180,11 @@ def _determine_key_conflicts(self, sims):
184180
f"Incompatible simulator output. \
185181
The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}."
186182
)
187-
elif isinstance(self.key_conflicts, (float, int)):
183+
elif self.key_conflicts == "fill":
188184
logging.info(
189185
f"Incompatible simulator output. \
190186
Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \
191-
with value {self.key_conflicts}."
187+
with value {self.fill_value}."
192188
)
193189

194190
return self._keys

tests/test_simulators/conftest.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,32 @@ def likelihood(mu, n):
167167
return make_simulator([prior, likelihood], meta_fn=context)
168168

169169

170-
@pytest.fixture(params=["drop", np.nan])
171-
def multimodel(request):
170+
@pytest.fixture()
171+
def multimodel():
172+
from bayesflow.simulators import make_simulator, ModelComparisonSimulator
173+
174+
def context(batch_size):
175+
return dict(n=np.random.randint(10, 100))
176+
177+
def prior_0():
178+
return dict(mu=0)
179+
180+
def prior_1():
181+
return dict(mu=np.random.standard_normal())
182+
183+
def likelihood(n, mu):
184+
return dict(y=np.random.normal(mu, 1, n))
185+
186+
simulator_0 = make_simulator([prior_0, likelihood])
187+
simulator_1 = make_simulator([prior_1, likelihood])
188+
189+
simulator = ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context)
190+
191+
return simulator
192+
193+
194+
@pytest.fixture(params=["drop", "fill", "error"])
195+
def multimodel_key_conflicts(request):
172196
from bayesflow.simulators import make_simulator, ModelComparisonSimulator
173197

174198
rng = np.random.default_rng()

tests/test_simulators/test_simulators.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import keras
23
import numpy as np
34

@@ -52,8 +53,19 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu):
5253
def test_multimodel_sample(multimodel, batch_size):
5354
samples = multimodel.sample(batch_size)
5455

55-
if multimodel.key_conflicts == "drop":
56+
assert set(samples) == {"n", "mu", "y", "model_indices"}
57+
assert samples["mu"].shape == (batch_size, 1)
58+
assert samples["y"].shape == (batch_size, samples["n"])
59+
60+
61+
def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size):
62+
if multimodel_key_conflicts.key_conflicts == "drop":
63+
samples = multimodel_key_conflicts.sample(batch_size)
5664
assert set(samples) == {"x", "model_indices"}
57-
else:
65+
elif multimodel_key_conflicts.key_conflicts == "fill":
66+
samples = multimodel_key_conflicts.sample(batch_size)
5867
assert set(samples) == {"x", "model_indices", "c", "w"}
5968
assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size
69+
elif multimodel_key_conflicts.key_conflicts == "error":
70+
with pytest.raises(Exception):
71+
samples = multimodel_key_conflicts.sample(batch_size)

0 commit comments

Comments
 (0)