Skip to content

Commit 39de287

Browse files
committed
check keys every time, issue warning only once
1 parent b486784 commit 39de287

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
self.use_mixed_batches = use_mixed_batches
8484
self.key_conflicts = key_conflicts
8585
self.fill_value = fill_value
86-
self._keys = None
86+
self._key_conflicts_warning = True
8787

8888
@allow_batch_size
8989
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -163,30 +163,28 @@ def _handle_key_conflicts(self, sims, batch_sizes):
163163
raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.")
164164

165165
def _determine_key_conflicts(self, sims):
166-
# determine only once
167-
if self._keys is not None:
168-
return self._keys
169-
170166
keys = [set(sim.keys()) for sim in sims]
171167
all_keys = set.union(*keys)
172168
common_keys = set.intersection(*keys)
173169
missing_keys = [all_keys - k for k in keys]
174170

175-
self._keys = keys, all_keys, common_keys, missing_keys
176-
177171
if all_keys == common_keys:
178-
return self._keys
172+
return keys, all_keys, common_keys, missing_keys
179173

180-
if self.key_conflicts == "drop":
181-
logging.info(
182-
f"Incompatible simulator output. \
174+
if self._key_conflicts_warning:
175+
# issue warning only once
176+
self._key_conflicts_warning = False
177+
178+
if self.key_conflicts == "drop":
179+
logging.info(
180+
f"Incompatible simulator output. \
183181
The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}."
184-
)
185-
elif self.key_conflicts == "fill":
186-
logging.info(
187-
f"Incompatible simulator output. \
182+
)
183+
elif self.key_conflicts == "fill":
184+
logging.info(
185+
f"Incompatible simulator output. \
188186
Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \
189187
with value {self.fill_value}."
190-
)
188+
)
191189

192-
return self._keys
190+
return keys, all_keys, common_keys, missing_keys

0 commit comments

Comments
 (0)