@@ -23,7 +23,8 @@ def __init__(
2323 p : Sequence [float ] = None ,
2424 logits : Sequence [float ] = None ,
2525 use_mixed_batches : bool = True ,
26- key_conflicts : str | float = "drop" ,
26+ key_conflicts : str = "drop" ,
27+ fill_value : float = np .nan ,
2728 shared_simulator : Simulator | Callable [[Sequence [int ]], dict [str , any ]] = None ,
2829 ):
2930 """
@@ -43,11 +44,13 @@ def __init__(
4344 Whether to draw samples in a batch from different models.
4445 - If True (default), each sample in a batch may come from a different model.
4546 - If False, the entire batch is drawn from a single model, selected according to model probabilities.
46- key_conflicts : {"drop"} | float , optional
47+ key_conflicts : {"drop", "fill", "error"}, optional
4748 Policy for handling keys that are missing in the output of some models, when using mixed batches.
4849 - "drop" (default): Drop conflicting keys from the batch output.
49- - float: Fill missing keys with the specified value.
50- - If neither "drop" nor a float is given, an error is raised when key conflicts are detected.
50+ - "fill": Fill missing keys with the value given by `fill_value`.
51+ - "error": An error is raised when key conflicts are detected.
52+ fill_value : float, optional
53+ If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument.
5154 shared_simulator : Simulator or Callable, optional
5255 A shared simulator whose outputs are passed to all model simulators. If a function is
5356 provided, it is wrapped in a `LambdaSimulator` with batching enabled.
@@ -77,6 +80,7 @@ def __init__(
7780 self .logits = logits
7881 self .use_mixed_batches = use_mixed_batches
7982 self .key_conflicts = key_conflicts
83+ self .fill_value = fill_value
8084 self ._keys = None
8185
8286 @allow_batch_size
@@ -139,30 +143,22 @@ def _handle_key_conflicts(self, sims, batch_sizes):
139143 if all_keys == common_keys :
140144 return sims
141145
142- # keep only common keys
143146 if self .key_conflicts == "drop" :
144147 sims = [{k : v for k , v in sim .items () if k in common_keys } for sim in sims ]
145148 return sims
146-
147- # try to fill with key_conflicts to shape of the values from other model
148- if isinstance (self .key_conflicts , (float , int )):
149+ elif self .key_conflicts == "fill" :
149150 combined_sims = {}
150151 for sim in sims :
151152 combined_sims = combined_sims | sim
152-
153153 for i , sim in enumerate (sims ):
154154 for missing_key in missing_keys [i ]:
155155 shape = combined_sims [missing_key ].shape
156156 shape = list (shape )
157157 shape [0 ] = batch_sizes [i ]
158-
159- sim [missing_key ] = np .full (shape = shape , fill_value = self .key_conflicts )
160-
158+ sim [missing_key ] = np .full (shape = shape , fill_value = self .fill_value )
161159 return sims
162-
163- raise ValueError (
164- "Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided."
165- )
160+ elif self .key_conflicts == "error" :
161+ raise ValueError ("Key conflicts are found in simulator outputs, cannot combine them into one batch." )
166162
167163 def _determine_key_conflicts (self , sims ):
168164 # determine only once
@@ -184,11 +180,11 @@ def _determine_key_conflicts(self, sims):
184180 f"Incompatible simulator output. \
185181 The following keys will be dropped: { ', ' .join (sorted (all_keys - common_keys ))} ."
186182 )
187- elif isinstance ( self .key_conflicts , ( float , int )) :
183+ elif self .key_conflicts == "fill" :
188184 logging .info (
189185 f"Incompatible simulator output. \
190186 Attempting to replace keys: { ', ' .join (sorted (all_keys - common_keys ))} , where missing, \
191- with value { self .key_conflicts } ."
187+ with value { self .fill_value } ."
192188 )
193189
194190 return self ._keys
0 commit comments