Skip to content

Commit db822a2

Browse files
committed
improve sampling (with explicit num_datasets for now)
1 parent 582a27b commit db822a2

File tree

3 files changed

+433
-800
lines changed

3 files changed

+433
-800
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from bayesflow.adapters import Adapter
1212
from bayesflow.networks import InferenceNetwork, SummaryNetwork
1313
from bayesflow.types import Tensor
14-
from bayesflow.utils import filter_kwargs, logging, split_arrays
14+
from bayesflow.utils import filter_kwargs, logging, split_arrays, concatenate_valid
1515
from .approximator import Approximator
1616

1717

@@ -213,6 +213,7 @@ def get_config(self):
213213
def sample(
214214
self,
215215
*,
216+
num_datasets: int,
216217
num_samples: int,
217218
conditions: dict[str, np.ndarray],
218219
split: bool = False,
@@ -224,6 +225,8 @@ def sample(
224225
225226
Parameters
226227
----------
228+
num_datasets: int, optional
229+
Number of datasets to generate.
227230
num_samples : int
228231
Number of samples to generate.
229232
conditions : dict[str, np.ndarray]
@@ -239,17 +242,54 @@ def sample(
239242
dict[str, np.ndarray]
240243
Dictionary containing generated samples with the same keys as `conditions`.
241244
"""
242-
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
243-
# at inference time, inference_variables are estimated by the networks and thus ignored in conditions
244-
conditions.pop("inference_variables", None)
245-
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
246-
conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
247-
conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
248-
conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
245+
conditions = self.adapter(conditions, strict=False)
246+
247+
if "inference_conditions" in conditions:
248+
inference_conditions = keras.ops.convert_to_tensor(conditions["inference_conditions"])
249+
else:
250+
inference_conditions = None
251+
252+
if "summary_conditions" in conditions:
253+
# we are directly supplied the summary network output
254+
if self.summary_network is None:
255+
raise ValueError("Cannot supply direct summary conditions without a summary network.")
256+
summary_conditions = keras.ops.convert_to_tensor(conditions["summary_conditions"])
257+
elif "summary_variables" in conditions:
258+
# we are supplied the summary network input
259+
if self.summary_network is None:
260+
raise ValueError("Cannot supply summary variables without a summary network.")
261+
summary_variables = keras.ops.convert_to_tensor(conditions["summary_variables"])
262+
summary_conditions = self.summary_network(summary_variables)
263+
else:
264+
# we are not supplied any summary statistics
265+
if self.summary_network is not None:
266+
raise ValueError("Summary conditions are required when a summary network is present.")
267+
summary_conditions = None
268+
269+
shape = [num_datasets, num_samples, keras.ops.shape(inference_conditions)[-1]]
270+
inference_conditions = keras.ops.broadcast_to(inference_conditions, shape)
271+
272+
shape = [num_datasets, num_samples, keras.ops.shape(summary_conditions)[-1]]
273+
summary_conditions = keras.ops.broadcast_to(summary_conditions, shape)
274+
275+
samples = self.inference_network.sample(
276+
(num_datasets, num_samples),
277+
conditions=concatenate_valid([inference_conditions, summary_conditions], axis=-1),
278+
**kwargs,
279+
)
280+
281+
samples = keras.ops.convert_to_numpy(samples)
282+
283+
samples = {
284+
"inference_variables": samples,
285+
**conditions,
286+
}
287+
samples = self.adapter(samples, inverse=True, strict=False)
249288

250289
if split:
251-
conditions = split_arrays(conditions, axis=-1)
252-
return conditions
290+
samples = split_arrays(samples, axis=-1)
291+
292+
return samples
253293

254294
def _sample(
255295
self,

examples/SIR_Posterior_Estimation.ipynb

Lines changed: 222 additions & 629 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)