@@ -83,7 +83,7 @@ def __init__(
8383 self .use_mixed_batches = use_mixed_batches
8484 self .key_conflicts = key_conflicts
8585 self .fill_value = fill_value
86- self ._keys = None
86+ self ._key_conflicts_warning = True
8787
8888 @allow_batch_size
8989 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
@@ -163,30 +163,28 @@ def _handle_key_conflicts(self, sims, batch_sizes):
163163 raise ValueError ("Key conflicts are found in simulator outputs, cannot combine them into one batch." )
164164
165165 def _determine_key_conflicts (self , sims ):
166- # determine only once
167- if self ._keys is not None :
168- return self ._keys
169-
170166 keys = [set (sim .keys ()) for sim in sims ]
171167 all_keys = set .union (* keys )
172168 common_keys = set .intersection (* keys )
173169 missing_keys = [all_keys - k for k in keys ]
174170
175- self ._keys = keys , all_keys , common_keys , missing_keys
176-
177171 if all_keys == common_keys :
178- return self . _keys
172+ return keys , all_keys , common_keys , missing_keys
179173
180- if self .key_conflicts == "drop" :
181- logging .info (
182- f"Incompatible simulator output. \
174+ if self ._key_conflicts_warning :
175+ # issue warning only once
176+ self ._key_conflicts_warning = False
177+
178+ if self .key_conflicts == "drop" :
179+ logging .info (
180+ f"Incompatible simulator output. \
183181 The following keys will be dropped: { ', ' .join (sorted (all_keys - common_keys ))} ."
184- )
185- elif self .key_conflicts == "fill" :
186- logging .info (
187- f"Incompatible simulator output. \
182+ )
183+ elif self .key_conflicts == "fill" :
184+ logging .info (
185+ f"Incompatible simulator output. \
188186 Attempting to replace keys: { ', ' .join (sorted (all_keys - common_keys ))} , where missing, \
189187 with value { self .fill_value } ."
190- )
188+ )
191189
192- return self . _keys
190+ return keys , all_keys , common_keys , missing_keys
0 commit comments