bayesflow/approximators/model_comparison_approximator.py (+87 −0)
@@ -22,6 +22,20 @@ class ModelComparisonApproximator(Approximator):
"""
Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are
learned with a classifier.

    Parameters
    ----------
    adapter : Adapter
        Adapter for data processing.
    num_models : int
        Number of models (simulators) that the approximator will compare.
    classifier_network : keras.Model
        The network (e.g., an MLP) that is used for model classification.
        The input of the classifier network is created by concatenating `classifier_variables`
        and the (optional) output of the `summary_network`.
    summary_network : SummaryNetwork, optional
        The summary network used for data summarisation (default is None).
        The input of the summary network is `summary_variables`.
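
    Examples
    --------
    A minimal construction sketch; the adapter, network sizes, and number of models below
    are illustrative placeholders rather than prescribed defaults:

    >>> import keras
    >>> classifier = keras.Sequential(
    ...     [keras.layers.Dense(64, activation="relu"), keras.layers.Dense(3)]
    ... )
    >>> approximator = ModelComparisonApproximator(
    ...     adapter=adapter,  # a preconfigured Adapter instance
    ...     num_models=3,
    ...     classifier_network=classifier,
    ... )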
"""

    def __init__(
@@ -161,6 +175,60 @@ def fit(
        simulators: Sequence[Simulator] = None,
        **kwargs,
    ):
"""
Trains the approximator on the provided dataset or on-demand generated from the given (multi-model) simulator.
If `dataset` is not provided, a dataset is built from the `simulator`.
If `simulator` is not provided, it will be build from a list of `simulators`.
If the model has not been built, it will be built using a batch from the dataset.

        Parameters
        ----------
        dataset : keras.utils.PyDataset, optional
            A dataset containing simulations for training. If provided, `simulator` must be None.
        simulator : ModelComparisonSimulator, optional
            A simulator used to generate a dataset. If provided, `dataset` must be None.
        simulators : Sequence[Simulator], optional
            A list of simulators (one simulator per model). If provided, `dataset` must be None.
        **kwargs : dict
            Additional keyword arguments passed to `keras.Model.fit()`, including (see also
            `build_dataset`):

            batch_size : int or None, default='auto'
                Number of samples per gradient update. Do not specify if `dataset` is provided
                as a `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`,
                or a generator function.
            epochs : int, default=1
                Number of epochs to train the model.
            verbose : {"auto", 0, 1, 2}, default="auto"
                Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
            callbacks : list of keras.callbacks.Callback, optional
                List of callbacks to apply during training.
            validation_split : float, optional
                Fraction of training data to use for validation (only supported if `dataset`
                consists of NumPy arrays or tensors).
            validation_data : tuple or dataset, optional
                Data for validation, overriding `validation_split`.
            shuffle : bool, default=True
                Whether to shuffle the training data before each epoch (ignored for dataset
                generators).
            initial_epoch : int, default=0
                Epoch at which to start training (useful for resuming training).
            steps_per_epoch : int or None, optional
                Number of steps (batches) before declaring an epoch finished.
            validation_steps : int or None, optional
                Number of validation steps per validation epoch.
            validation_batch_size : int or None, optional
                Number of samples per validation batch (defaults to `batch_size`).
            validation_freq : int, default=1
                Number of training epochs to run between validation runs.

        Returns
        -------
        keras.callbacks.History
            A history object containing the training loss and metrics values.

        Raises
        ------
        ValueError
            If both a `dataset` and a `simulator` (or `simulators`) are provided, or if
            neither is provided.
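
        Examples
        --------
        An illustrative sketch, assuming a preconfigured `ModelComparisonSimulator` named
        `simulator`; the epoch and batch settings are placeholders:

        >>> history = approximator.fit(
        ...     simulator=simulator,
        ...     epochs=10,
        ...     batch_size=32,
        ... )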
"""
        if dataset is not None:
            if simulator is not None or simulators is not None:
                raise ValueError(
@@ -207,6 +275,25 @@ def predict(
        logits: bool = False,
        **kwargs,
    ) -> np.ndarray:
"""
Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed
using the `adapter`. The output is converted to NumPy array after inference.

        Parameters
        ----------
        conditions : dict[str, np.ndarray]
            Dictionary of conditioning variables as NumPy arrays.
        logits : bool, default=False
            If `True`, the posterior model probabilities are returned as (unconstrained) logits.
            If `False`, the returned probabilities lie on the unit simplex.
        **kwargs : dict
            Additional keyword arguments for the adapter and classification process.

        Returns
        -------
        np.ndarray
            Predicted posterior model probabilities given `conditions`.
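
        Examples
        --------
        An illustrative call; the key `"x"` and the array `x_obs` stand in for whatever
        conditioning variables your adapter expects:

        >>> probs = approximator.predict(conditions={"x": x_obs})
        >>> probs.sum(axis=-1)  # rows sum to one when `logits=False`
        >>> log_odds = approximator.predict(conditions={"x": x_obs}, logits=True)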
"""
        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
        # at inference time, model_indices are predicted by the networks and thus ignored in conditions
        conditions.pop("model_indices", None)