@@ -78,7 +78,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         ----------
         batch_shape : Shape
             The shape of the batch to sample. Typically, a tuple indicating the number of samples,
-            but can also be an int.
+            but the user can also supply an int.
         **kwargs
             Additional keyword arguments passed to each simulator. These may include outputs from
             the shared simulator.
@@ -95,30 +95,26 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
         if self.shared_simulator:
             data |= self.shared_simulator.sample(batch_shape, **kwargs)
 
-        if not self.use_mixed_batches:
-            # draw one model index for the whole batch (faster)
-            model_index = np.random.choice(len(self.simulators), p=npu.softmax(self.logits))
+        softmax_logits = npu.softmax(self.logits)
+        num_models = len(self.simulators)
 
-            simulator = self.simulators[model_index]
-            data = simulator.sample(batch_shape, **(kwargs | data))
-
-            model_indices = np.full(batch_shape, model_index, dtype="int32")
-            model_indices = npu.one_hot(model_indices, len(self.simulators))
-        else:
-            # generate data randomly from each model (slower)
-            model_counts = np.random.multinomial(n=batch_shape[0], pvals=npu.softmax(self.logits))
-
-            sims = []
-            for n, simulator in zip(model_counts, self.simulators):
-                if n == 0:
-                    continue
-                sim = simulator.sample(n, **(kwargs | data))
-                sims.append(sim)
+        # generate data randomly from each model (slower)
+        if self.use_mixed_batches:
+            model_counts = np.random.multinomial(n=batch_shape[0], pvals=softmax_logits)
 
+            sims = [
+                simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
+            ]
             sims = tree_concatenate(sims, numpy=True)
             data |= sims
 
-            model_indices = np.eye(len(self.simulators), dtype="int32")
-            model_indices = np.repeat(model_indices, model_counts, axis=0)
+            model_indices = np.repeat(np.eye(num_models, dtype="int32"), model_counts, axis=0)
+
+        # draw one model index for the whole batch (faster)
+        else:
+            model_index = np.random.choice(num_models, p=softmax_logits)
+
+            data = self.simulators[model_index].sample(batch_shape, **(kwargs | data))
+            model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)
 
         return data | {"model_indices": model_indices}
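For readers skimming the new `use_mixed_batches` branch: the model-index bookkeeping reduces to a multinomial draw of per-model counts plus row-wise repetition of an identity matrix. A minimal standalone sketch of just that bookkeeping, outside the library (the plain `softmax` helper, the seed, and the choice of three models are illustrative assumptions, not part of the codebase):

```python
import numpy as np


def softmax(logits):
    # numerically stable softmax over a 1D array of logits (illustrative stand-in for npu.softmax)
    z = np.asarray(logits, dtype="float64")
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


rng = np.random.default_rng(0)

logits = np.zeros(3)   # three hypothetical models with equal prior weight
batch_size = 10
probs = softmax(logits)

# mixed-batch branch: draw how many samples each model contributes to the batch
model_counts = rng.multinomial(batch_size, probs)

# one-hot model indices, one row per sample, aligned with the concatenated simulator outputs
model_indices = np.repeat(np.eye(len(logits), dtype="int32"), model_counts, axis=0)

print(model_counts)               # e.g. [3 4 3]
print(model_indices.shape)        # (10, 3)
print(model_indices.sum(axis=0))  # recovers model_counts
```

Because `np.repeat` expands the identity rows in model order, the indices line up with the per-model samples concatenated in that same order by `tree_concatenate`.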