Skip to content

Commit 01e56ea

Browse files
committed
overriding simulator outputs with auto-batched inputs
1 parent dc5ee17 commit 01e56ea

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

bayesflow/simulators/sequential_simulator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class SequentialSimulator(Simulator):
1111
"""Combines multiple simulators into one, sequentially."""
1212

13-
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True):
13+
def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True, replace_inputs: bool = True):
1414
"""
1515
Initialize a SequentialSimulator.
1616
@@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
2222
expand_outputs : bool, optional
2323
If True, 1D output arrays are expanded with an additional dimension at the end.
2424
Default is True.
25+
replace_inputs : bool, optional
26+
If True, **kwargs are auto-batched and replace simulator outputs.
2527
"""
2628

2729
self.simulators = simulators
2830
self.expand_outputs = expand_outputs
31+
self.replace_inputs = replace_inputs
2932

3033
@allow_batch_size
3134
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
@@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
5356
for simulator in self.simulators:
5457
data |= simulator.sample(batch_shape, **(kwargs | data))
5558

59+
if self.replace_inputs:
60+
common_keys = set(data.keys()) & set(kwargs.keys())
61+
for key in common_keys:
62+
value = kwargs.pop(key)
63+
if isinstance(data[key], np.ndarray):
64+
value = np.broadcast_to(value, data[key].shape)
65+
data[key] = value
66+
5667
if self.expand_outputs:
5768
data = {
5869
key: np.expand_dims(value, axis=-1) if np.ndim(value) == 1 else value for key, value in data.items()

0 commit comments

Comments
 (0)