Skip to content

Commit 86f6f41

Browse files
committed
drop or fill missing keys from the output
1 parent 5b5363c commit 86f6f41

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from bayesflow.utils.decorators import allow_batch_size
77

88
from bayesflow.utils import numpy_utils as npu
9+
from bayesflow.utils import logging
910

1011
from types import FunctionType
1112

@@ -22,6 +23,7 @@ def __init__(
2223
p: Sequence[float] = None,
2324
logits: Sequence[float] = None,
2425
use_mixed_batches: bool = True,
26+
key_conflicts: str | float = "drop",
2527
shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None,
2628
):
2729
"""
@@ -38,8 +40,14 @@ def __init__(
3840
A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
3941
If neither `p` nor `logits` is provided, defaults to uniform logits.
4042
use_mixed_batches : bool, optional
41-
If True, samples in a batch are drawn from different models. If False, the entire batch
42-
is drawn from a single model chosen according to the model probabilities. Default is True.
43+
Whether to draw samples in a batch from different models.
44+
- If True (default), each sample in a batch may come from a different model.
45+
- If False, the entire batch is drawn from a single model, selected according to model probabilities.
46+
key_conflicts : {"drop"} | float, optional
47+
Policy for handling keys that are missing in the output of some models, when using mixed batches.
48+
- "drop" (default): Drop conflicting keys from the batch output.
49+
- float: Fill missing keys with the specified value.
50+
- If neither "drop" nor a float is given, an error is raised when key conflicts are detected.
4351
shared_simulator : Simulator or Callable, optional
4452
A shared simulator whose outputs are passed to all model simulators. If a function is
4553
provided, it is wrapped in a `LambdaSimulator` with batching enabled.
@@ -68,6 +76,8 @@ def __init__(
6876

6977
self.logits = logits
7078
self.use_mixed_batches = use_mixed_batches
79+
self.key_conflicts = key_conflicts
80+
self._keys = None
7181

7282
@allow_batch_size
7383
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -105,6 +115,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
105115
sims = [
106116
simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
107117
]
118+
sims = self._handle_key_conflicts(sims, model_counts)
108119
sims = tree_concatenate(sims, numpy=True)
109120
data |= sims
110121

@@ -118,3 +129,66 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
118129
model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)
119130

120131
return data | {"model_indices": model_indices}
132+
133+
def _handle_key_conflicts(self, sims, batch_sizes):
134+
batch_sizes = [b for b in batch_sizes if b > 0]
135+
136+
keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims)
137+
138+
# all sims have the same keys
139+
if all_keys == common_keys:
140+
return sims
141+
142+
# keep only common keys
143+
if self.key_conflicts == "drop":
144+
sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims]
145+
return sims
146+
147+
# try to fill values with key_conflicts to shape of sims from other models
148+
if isinstance(self.key_conflicts, (float, int)):
149+
combined_sims = {}
150+
for sim in sims:
151+
combined_sims = combined_sims | sim
152+
153+
for i, sim in enumerate(sims):
154+
for missing_key in missing_keys[i]:
155+
shape = combined_sims[missing_key].shape
156+
shape = [s for s in shape]
157+
shape[0] = batch_sizes[i]
158+
159+
sim[missing_key] = np.full(shape=shape, fill_value=self.key_conflicts)
160+
161+
return sims
162+
163+
raise ValueError(
164+
"Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided."
165+
)
166+
167+
def _determine_key_conflicts(self, sims):
168+
# determine only once
169+
if self._keys is not None:
170+
return self._keys
171+
172+
keys = [set(sim.keys()) for sim in sims]
173+
all_keys = set.union(*keys)
174+
common_keys = set.intersection(*keys)
175+
missing_keys = [all_keys - k for k in keys]
176+
177+
self._keys = keys, all_keys, common_keys, missing_keys
178+
179+
if all_keys == common_keys:
180+
return self._keys
181+
182+
if self.key_conflicts == "drop":
183+
logging.info(
184+
f"Incompatible simulator output. \
185+
The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}."
186+
)
187+
elif isinstance(self.key_conflicts, (float, int)):
188+
logging.info(
189+
f"Incompatible simulator output. \
190+
Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \
191+
with value {self.key_conflicts}."
192+
)
193+
194+
return self._keys

0 commit comments

Comments
 (0)