66from bayesflow .utils .decorators import allow_batch_size
77
88from bayesflow .utils import numpy_utils as npu
9+ from bayesflow .utils import logging
910
1011from types import FunctionType
1112
@@ -22,6 +23,7 @@ def __init__(
2223 p : Sequence [float ] = None ,
2324 logits : Sequence [float ] = None ,
2425 use_mixed_batches : bool = True ,
26+ key_conflicts : str | float = "drop" ,
2527 shared_simulator : Simulator | Callable [[Sequence [int ]], dict [str , any ]] = None ,
2628 ):
2729 """
@@ -38,8 +40,14 @@ def __init__(
3840 A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
3941 If neither `p` nor `logits` is provided, defaults to uniform logits.
4042 use_mixed_batches : bool, optional
41- If True, samples in a batch are drawn from different models. If False, the entire batch
42- is drawn from a single model chosen according to the model probabilities. Default is True.
43+ Whether to draw samples in a batch from different models.
44+ - If True (default), each sample in a batch may come from a different model.
45+ - If False, the entire batch is drawn from a single model, selected according to model probabilities.
46+ key_conflicts : {"drop"} | float, optional
47+ Policy for handling keys that are missing in the output of some models, when using mixed batches.
48+ - "drop" (default): Drop conflicting keys from the batch output.
49+ - float: Fill missing keys with the specified value.
50+ - If neither "drop" nor a float is given, an error is raised when key conflicts are detected.
4351 shared_simulator : Simulator or Callable, optional
4452 A shared simulator whose outputs are passed to all model simulators. If a function is
4553 provided, it is wrapped in a `LambdaSimulator` with batching enabled.
@@ -68,6 +76,8 @@ def __init__(
6876
6977 self .logits = logits
7078 self .use_mixed_batches = use_mixed_batches
79+ self .key_conflicts = key_conflicts
80+ self ._keys = None
7181
7282 @allow_batch_size
7383 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
@@ -105,6 +115,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
105115 sims = [
106116 simulator .sample (n , ** (kwargs | data )) for simulator , n in zip (self .simulators , model_counts ) if n > 0
107117 ]
118+ sims = self ._handle_key_conflicts (sims , model_counts )
108119 sims = tree_concatenate (sims , numpy = True )
109120 data |= sims
110121
@@ -118,3 +129,66 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
118129 model_indices = npu .one_hot (np .full (batch_shape , model_index , dtype = "int32" ), num_models )
119130
120131 return data | {"model_indices" : model_indices }
132+
def _handle_key_conflicts(self, sims, batch_sizes):
    """
    Reconcile differing output keys across per-model simulations.

    Parameters
    ----------
    sims : list of dict
        One simulation output per model that produced samples for this batch.
    batch_sizes : sequence of int
        Number of samples per model. Entries that are not greater than zero are
        ignored, so the remaining entries are indexed in step with `sims`.

    Returns
    -------
    list of dict
        `sims` unchanged if every simulation has the same keys; a new list of
        dictionaries restricted to the shared keys if ``self.key_conflicts == "drop"``;
        or `sims` with missing entries added in place if ``self.key_conflicts`` is a number.

    Raises
    ------
    ValueError
        If keys differ and ``self.key_conflicts`` is neither "drop" nor a number.
    """
    # keep only the sizes of models that actually contributed a simulation
    active_sizes = [size for size in batch_sizes if size > 0]

    _, every_key, shared_keys, absent_per_sim = self._determine_key_conflicts(sims=sims)

    # nothing to reconcile
    if every_key == shared_keys:
        return sims

    policy = self.key_conflicts

    # policy "drop": discard everything that is not present in all simulations
    if policy == "drop":
        return [{name: value for name, value in sim.items() if name in shared_keys} for sim in sims]

    if not isinstance(policy, (float, int)):
        raise ValueError(
            "Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided."
        )

    # numeric policy: look up a template array for every key (the last simulation providing it wins)
    templates = {name: value for sim in sims for name, value in sim.items()}

    for index, sim in enumerate(sims):
        for name in absent_per_sim[index]:
            # copy the template's shape, but use this simulation's own batch size on the leading axis
            fill_shape = list(templates[name].shape)
            fill_shape[0] = active_sizes[index]

            sim[name] = np.full(shape=fill_shape, fill_value=policy)

    return sims
166+
def _determine_key_conflicts(self, sims):
    """
    Compare the output keys of the given simulations and report a mismatch once.

    The key sets are recomputed on every call and describe only the simulations
    passed in. The informational log message is emitted at most once per instance,
    the first time a mismatch is detected.

    Parameters
    ----------
    sims : list of dict
        One simulation output per model that produced samples for the current batch.

    Returns
    -------
    keys : list of set
        Key set of each entry in `sims`, in the same order.
    all_keys : set
        Union of all key sets.
    common_keys : set
        Intersection of all key sets (empty if `sims` is empty).
    missing_keys : list of set
        For each entry in `sims`, the keys of `all_keys` that it lacks.
    """
    # Recomputed for every batch: models that received no samples are skipped by the caller,
    # so the number and order of entries in `sims` can differ between calls. Reusing the
    # result of an earlier call would pair `missing_keys[i]` with the wrong simulation.
    keys = [set(sim.keys()) for sim in sims]
    all_keys = set().union(*keys)
    common_keys = set.intersection(*keys) if keys else set()
    missing_keys = [all_keys - k for k in keys]
    result = keys, all_keys, common_keys, missing_keys

    # no conflict in this batch, or the conflict has already been reported
    if all_keys == common_keys or self._keys is not None:
        return result

    # remember the first detected conflict so that the message below is logged only once
    self._keys = result
    conflicting = ", ".join(sorted(all_keys - common_keys))

    if self.key_conflicts == "drop":
        logging.info(f"Incompatible simulator output. The following keys will be dropped: {conflicting}.")
    elif isinstance(self.key_conflicts, (float, int)):
        logging.info(
            f"Incompatible simulator output. Attempting to replace keys: {conflicting}, "
            f"where missing, with value {self.key_conflicts}."
        )

    return result
0 commit comments