From 97838a70314baf93a53e1b6ea7ec21c970341a72 Mon Sep 17 00:00:00 2001
From: Kucharssim
Date: Tue, 11 Mar 2025 09:36:16 +0100
Subject: [PATCH] add docs to ModelComparisonApproximator

---
 .../model_comparison_approximator.py | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py
index 8c42e3cc6..7e9015fa2 100644
--- a/bayesflow/approximators/model_comparison_approximator.py
+++ b/bayesflow/approximators/model_comparison_approximator.py
@@ -22,6 +22,20 @@ class ModelComparisonApproximator(Approximator):
     """
     Defines an approximator for model (simulator) comparison, where the (discrete) posterior model
     probabilities are learned with a classifier.
+
+    Parameters
+    ----------
+    adapter : Adapter
+        Adapter for data processing.
+    num_models : int
+        Number of models (simulators) that the approximator will compare.
+    classifier_network : keras.Model
+        The network (e.g., an MLP) used for model classification.
+        The input of the classifier network is created by concatenating the `classifier_variables`
+        and the (optional) output of the summary network.
+    summary_network : SummaryNetwork, optional
+        The summary network used for data summarization (default is None).
+        The input of the summary network is `summary_variables`.
     """
 
     def __init__(
@@ -161,6 +175,60 @@ def fit(
         simulators: Sequence[Simulator] = None,
         **kwargs,
     ):
+        """
+        Trains the approximator on the provided dataset or on data generated on the fly from the given (multi-model) simulator.
+        If `dataset` is not provided, a dataset is built from the `simulator`.
+        If `simulator` is not provided, it is built from the list of `simulators`.
+        If the model has not been built, it will be built using a batch from the dataset.
+
+        Parameters
+        ----------
+        dataset : keras.utils.PyDataset, optional
+            A dataset containing simulations for training. If provided, `simulator` must be None.
+        simulator : ModelComparisonSimulator, optional
+            A simulator used to generate a dataset. If provided, `dataset` must be None.
+        simulators : Sequence[Simulator], optional
+            A list of simulators (one simulator per model). If provided, `dataset` must be None.
+        **kwargs : dict
+            Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
+
+            batch_size : int or None, default='auto'
+                Number of samples per gradient update. Do not specify if `dataset` is provided as a
+                `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
+            epochs : int, default=1
+                Number of epochs to train the model.
+            verbose : {"auto", 0, 1, 2}, default="auto"
+                Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
+            callbacks : list of keras.callbacks.Callback, optional
+                List of callbacks to apply during training.
+            validation_split : float, optional
+                Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
+                or tensors).
+            validation_data : tuple or dataset, optional
+                Data for validation, overriding `validation_split`.
+            shuffle : bool, default=True
+                Whether to shuffle the training data before each epoch (ignored for dataset generators).
+            initial_epoch : int, default=0
+                Epoch at which to start training (useful for resuming training).
+            steps_per_epoch : int or None, optional
+                Number of steps (batches) before declaring an epoch finished.
+            validation_steps : int or None, optional
+                Number of validation steps per validation epoch.
+            validation_batch_size : int or None, optional
+                Number of samples per validation batch (defaults to `batch_size`).
+            validation_freq : int, default=1
+                Specifies how many training epochs to run before performing validation.
+
+        Returns
+        -------
+        keras.callbacks.History
+            A history object containing the training loss and metric values.
+
+        Raises
+        ------
+        ValueError
+            If both a `dataset` and a `simulator` (or `simulators`) are provided, or if neither is provided.
+        """
         if dataset is not None:
             if simulator is not None or simulators is not None:
                 raise ValueError(
@@ -207,6 +275,25 @@ def predict(
         logits: bool = False,
         **kwargs,
     ) -> np.ndarray:
+        """
+        Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed
+        using the `adapter`. The output is converted to a NumPy array after inference.
+
+        Parameters
+        ----------
+        conditions : dict[str, np.ndarray]
+            Dictionary of conditioning variables as NumPy arrays.
+        logits : bool, default=False
+            If `True`, the posterior model probabilities are returned as (unconstrained) logits.
+            If `False` (default), the probabilities are normalized to lie on the unit simplex.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and classification process.
+
+        Returns
+        -------
+        np.ndarray
+            Predicted posterior model probabilities given `conditions`.
+        """
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
         # at inference time, model_indices are predicted by the networks and thus ignored in conditions
         conditions.pop("model_indices", None)
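
For reviewers, a minimal usage sketch of the documented interface follows; it relies only on the constructor arguments and the `fit`/`predict` signatures described in the docstrings above. The adapter construction (`make_adapter`), the two placeholder simulators, and the condition key used for prediction are illustrative assumptions, not part of this patch or of the library API.

import keras
import numpy as np

from bayesflow.approximators import ModelComparisonApproximator

# Assumed to exist elsewhere: an Adapter that maps raw simulation outputs to
# "classifier_variables" (and optionally "summary_variables"), plus one Simulator
# per candidate model. Their construction is hypothetical and not shown here.
adapter = make_adapter()                              # hypothetical helper
simulators = [simulator_model_0, simulator_model_1]   # hypothetical Simulator objects

# A plain Keras MLP acts as the classifier network; in this sketch its final
# layer is sized to the number of compared models.
classifier_network = keras.Sequential(
    [
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(len(simulators)),
    ]
)

approximator = ModelComparisonApproximator(
    adapter=adapter,
    num_models=len(simulators),
    classifier_network=classifier_network,
    summary_network=None,  # no summary network in this sketch
)

# Exactly one of `dataset`, `simulator`, or `simulators` may be given; further
# keyword arguments are forwarded to keras.Model.fit() / build_dataset().
history = approximator.fit(simulators=simulators, epochs=10, batch_size=32)

# Posterior model probabilities for new data; the key "observables" is only an
# example and must match whatever the adapter expects.
conditions = {"observables": np.random.normal(size=(100, 10))}
probabilities = approximator.predict(conditions=conditions)               # on the unit simplex
logit_scores = approximator.predict(conditions=conditions, logits=True)   # unconstrained logits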