Skip to content

Commit eb35a14

Browse files
committed
Include shared context in MultiSimulationDataset for offline training
1 parent 225a817 commit eb35a14

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

bayesflow/helper_classes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,18 @@ def __init__(self, forward_dict, batch_size, buffer_size=1024):
141141
self.iters = [iter(d) for d in self.datasets]
142142
self.batch_size = batch_size
143143

144+
# Include further keys (= shared context) from forward_dict
145+
self.further_keys = {}
146+
for key, value in forward_dict.items():
147+
if key not in [DEFAULT_KEYS["model_outputs"], DEFAULT_KEYS["model_indices"]]:
148+
self.further_keys[key] = value
149+
144150
def __next__(self):
145151
if self.current_it < self.num_batches:
146152
outputs = [next(d) for d in self.iters]
147153
output_dict = {DEFAULT_KEYS["model_outputs"]: outputs, DEFAULT_KEYS["model_indices"]: self.model_indices}
154+
if self.further_keys:
155+
output_dict.update(self.further_keys)
148156
self.current_it += 1
149157
return output_dict
150158
self.current_it = 0

0 commit comments

Comments
 (0)