Skip to content

Commit cc4db07

Browse files
committed
Small improvements to model comp simulator
1 parent 45c43f2 commit cc4db07

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
7878
----------
7979
batch_shape : Shape
8080
The shape of the batch to sample. Typically, a tuple indicating the number of samples,
81-
but can also be an int.
81+
but the user can also supply an int.
8282
**kwargs
8383
Additional keyword arguments passed to each simulator. These may include outputs from
8484
the shared simulator.
@@ -95,30 +95,26 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
9595
if self.shared_simulator:
9696
data |= self.shared_simulator.sample(batch_shape, **kwargs)
9797

98-
if not self.use_mixed_batches:
99-
# draw one model index for the whole batch (faster)
100-
model_index = np.random.choice(len(self.simulators), p=npu.softmax(self.logits))
98+
softmax_logits = npu.softmax(self.logits)
99+
num_models = len(self.simulators)
101100

102-
simulator = self.simulators[model_index]
103-
data = simulator.sample(batch_shape, **(kwargs | data))
104-
105-
model_indices = np.full(batch_shape, model_index, dtype="int32")
106-
model_indices = npu.one_hot(model_indices, len(self.simulators))
107-
else:
108-
# generate data randomly from each model (slower)
109-
model_counts = np.random.multinomial(n=batch_shape[0], pvals=npu.softmax(self.logits))
110-
111-
sims = []
112-
for n, simulator in zip(model_counts, self.simulators):
113-
if n == 0:
114-
continue
115-
sim = simulator.sample(n, **(kwargs | data))
116-
sims.append(sim)
101+
# generate data randomly from each model (slower)
102+
if self.use_mixed_batches:
103+
model_counts = np.random.multinomial(n=batch_shape[0], pvals=softmax_logits)
117104

105+
sims = [
106+
simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0
107+
]
118108
sims = tree_concatenate(sims, numpy=True)
119109
data |= sims
120110

121-
model_indices = np.eye(len(self.simulators), dtype="int32")
122-
model_indices = np.repeat(model_indices, model_counts, axis=0)
111+
model_indices = np.repeat(np.eye(num_models, dtype="int32"), model_counts, axis=0)
112+
113+
# draw one model index for the whole batch (faster)
114+
else:
115+
model_index = np.random.choice(num_models, p=softmax_logits)
116+
117+
data = self.simulators[model_index].sample(batch_shape, **(kwargs | data))
118+
model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models)
123119

124120
return data | {"model_indices": model_indices}

0 commit comments

Comments
 (0)