1010class SequentialSimulator (Simulator ):
1111 """Combines multiple simulators into one, sequentially."""
1212
13- def __init__ (self , simulators : Sequence [Simulator ], expand_outputs : bool = True ):
13+ def __init__ (self , simulators : Sequence [Simulator ], expand_outputs : bool = True , replace_inputs : bool = True ):
1414 """
1515 Initialize a SequentialSimulator.
1616
@@ -22,10 +22,13 @@ def __init__(self, simulators: Sequence[Simulator], expand_outputs: bool = True)
2222 expand_outputs : bool, optional
2323 If True, 1D output arrays are expanded with an additional dimension at the end.
2424 Default is True.
25+ replace_inputs : bool, optional
26+ If True, **kwargs are auto-batched and replace simulator outputs.
2527 """
2628
2729 self .simulators = simulators
2830 self .expand_outputs = expand_outputs
31+ self .replace_inputs = replace_inputs
2932
3033 @allow_batch_size
3134 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
@@ -53,6 +56,14 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
5356 for simulator in self .simulators :
5457 data |= simulator .sample (batch_shape , ** (kwargs | data ))
5558
59+ if self .replace_inputs :
60+ common_keys = set (data .keys ()) & set (kwargs .keys ())
61+ for key in common_keys :
62+ value = kwargs .pop (key )
63+ if isinstance (data [key ], np .ndarray ):
64+ value = np .broadcast_to (value , data [key ].shape )
65+ data [key ] = value
66+
5667 if self .expand_outputs :
5768 data = {
5869 key : np .expand_dims (value , axis = - 1 ) if np .ndim (value ) == 1 else value for key , value in data .items ()
0 commit comments