@@ -24,6 +24,26 @@ def __init__(
2424 use_mixed_batches : bool = True ,
2525 shared_simulator : Simulator | FunctionType = None ,
2626 ):
27+ """
28+ Initialize a multi-model simulator that can generate data for mixture / model comparison problems.
29+
30+ Parameters
31+ ----------
32+ simulators : Sequence[Simulator]
33+ A sequence of simulator instances, each representing a different model.
34+ p : Sequence[float], optional
35+ A sequence of probabilities associated with each simulator. Must sum to 1.
36+ Mutually exclusive with `logits`.
37+ logits : Sequence[float], optional
38+ A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`.
39+ If neither `p` nor `logits` is provided, defaults to uniform logits.
40+ use_mixed_batches : bool, optional
41+ If True, samples in a batch are drawn from different models. If False, the entire batch
42+ is drawn from a single model chosen according to the model probabilities. Default is True.
43+ shared_simulator : Simulator or FunctionType, optional
44+ A shared simulator whose outputs are passed to all model simulators. If a function is
45+ provided, it is wrapped in a `LambdaSimulator` with batching enabled.
46+ """
2747 self .simulators = simulators
2848
2949 if isinstance (shared_simulator , FunctionType ):
@@ -51,34 +71,50 @@ def __init__(
5171
5272 @allow_batch_size
5373 def sample (self , batch_shape : Shape , ** kwargs ) -> dict [str , np .ndarray ]:
74+ """
75+ Sample from the model comparison simulator.
76+
77+ Parameters
78+ ----------
79+ batch_shape : Shape
80+ The shape of the batch to sample. Typically, a tuple indicating the number of samples,
81+ but the user can also supply an int.
82+ **kwargs
83+ Additional keyword arguments passed to each simulator. These may include outputs from
84+ the shared simulator.
85+
86+ Returns
87+ -------
88+ data : dict of str to np.ndarray
89+ A dictionary containing the sampled outputs. Includes:
90+ - outputs from the selected simulator(s)
91+ - optionally, outputs from the shared simulator
92+ - "model_indices": a one-hot encoded array indicating the model origin of each sample
93+ """
5494 data = {}
5595 if self .shared_simulator :
5696 data |= self .shared_simulator .sample (batch_shape , ** kwargs )
5797
58- if not self .use_mixed_batches :
59- # draw one model index for the whole batch (faster)
60- model_index = np .random .choice (len (self .simulators ), p = npu .softmax (self .logits ))
98+ softmax_logits = npu .softmax (self .logits )
99+ num_models = len (self .simulators )
61100
62- simulator = self .simulators [model_index ]
63- data = simulator .sample (batch_shape , ** (kwargs | data ))
64-
65- model_indices = np .full (batch_shape , model_index , dtype = "int32" )
66- model_indices = npu .one_hot (model_indices , len (self .simulators ))
67- else :
68- # generate data randomly from each model (slower)
69- model_counts = np .random .multinomial (n = batch_shape [0 ], pvals = npu .softmax (self .logits ))
70-
71- sims = []
72- for n , simulator in zip (model_counts , self .simulators ):
73- if n == 0 :
74- continue
75- sim = simulator .sample (n , ** (kwargs | data ))
76- sims .append (sim )
101+ # generate data randomly from each model (slower)
102+ if self .use_mixed_batches :
103+ model_counts = np .random .multinomial (n = batch_shape [0 ], pvals = softmax_logits )
77104
105+ sims = [
106+ simulator .sample (n , ** (kwargs | data )) for simulator , n in zip (self .simulators , model_counts ) if n > 0
107+ ]
78108 sims = tree_concatenate (sims , numpy = True )
79109 data |= sims
80110
81- model_indices = np .eye (len (self .simulators ), dtype = "int32" )
82- model_indices = np .repeat (model_indices , model_counts , axis = 0 )
111+ model_indices = np .repeat (np .eye (num_models , dtype = "int32" ), model_counts , axis = 0 )
112+
113+ # draw one model index for the whole batch (faster)
114+ else :
115+ model_index = np .random .choice (num_models , p = softmax_logits )
116+
117+ data = self .simulators [model_index ].sample (batch_shape , ** (kwargs | data ))
118+ model_indices = npu .one_hot (np .full (batch_shape , model_index , dtype = "int32" ), num_models )
83119
84120 return data | {"model_indices" : model_indices }
0 commit comments