Skip to content

Commit 97838a7

Browse files
committed
add docs to ModelComparisonApproximator
1 parent 3ad23d6 commit 97838a7

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@ class ModelComparisonApproximator(Approximator):
2222
"""
2323
Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are
2424
learned with a classifier.
25+
26+
Parameters
27+
----------
28+
adapter: Adapter
29+
Adapter for data processing.
30+
num_models: int
31+
Number of models (simulators) that the approximator will compare.
32+
classifier_network: keras.Model
33+
The network (e.g., an MLP) that is used for model classification.
34+
The input of the classifier network is created by concatenating `classifier_variables`
35+
and (optional) output of the summary_network.
36+
summary_network: SummaryNetwork, optional
37+
The summary network used for data summarisation (default is None).
38+
The input of the summary network is `summary_variables`.
2539
"""
2640

2741
def __init__(
@@ -161,6 +175,60 @@ def fit(
161175
simulators: Sequence[Simulator] = None,
162176
**kwargs,
163177
):
178+
"""
179+
Trains the approximator on the provided dataset or on data generated on demand from the given (multi-model) simulator.
180+
If `dataset` is not provided, a dataset is built from the `simulator`.
181+
If `simulator` is not provided, it will be built from a list of `simulators`.
182+
If the model has not been built, it will be built using a batch from the dataset.
183+
184+
Parameters
185+
----------
186+
dataset : keras.utils.PyDataset, optional
187+
A dataset containing simulations for training. If provided, `simulator` must be None.
188+
simulator : ModelComparisonSimulator, optional
189+
A simulator used to generate a dataset. If provided, `dataset` must be None.
190+
simulators: Sequence[Simulator], optional
191+
A list of simulators (one simulator per model). If provided, `dataset` must be None.
192+
**kwargs : dict
193+
Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
194+
195+
batch_size : int or None, default='auto'
196+
Number of samples per gradient update. Do not specify if `dataset` is provided as a
197+
`keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
198+
epochs : int, default=1
199+
Number of epochs to train the model.
200+
verbose : {"auto", 0, 1, 2}, default="auto"
201+
Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
202+
callbacks : list of keras.callbacks.Callback, optional
203+
List of callbacks to apply during training.
204+
validation_split : float, optional
205+
Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
206+
or tensors).
207+
validation_data : tuple or dataset, optional
208+
Data for validation, overriding `validation_split`.
209+
shuffle : bool, default=True
210+
Whether to shuffle the training data before each epoch (ignored for dataset generators).
211+
initial_epoch : int, default=0
212+
Epoch at which to start training (useful for resuming training).
213+
steps_per_epoch : int or None, optional
214+
Number of steps (batches) before declaring an epoch finished.
215+
validation_steps : int or None, optional
216+
Number of validation steps per validation epoch.
217+
validation_batch_size : int or None, optional
218+
Number of samples per validation batch (defaults to `batch_size`).
219+
validation_freq : int, default=1
220+
Specifies how many training epochs to run before performing validation.
221+
222+
Returns
223+
-------
224+
keras.callbacks.History
225+
A history object containing the training loss and metrics values.
226+
227+
Raises
228+
------
229+
ValueError
230+
If both `dataset` and `simulator` or `simulators` are provided or neither is provided.
231+
"""
164232
if dataset is not None:
165233
if simulator is not None or simulators is not None:
166234
raise ValueError(
@@ -207,6 +275,25 @@ def predict(
207275
logits: bool = False,
208276
**kwargs,
209277
) -> np.ndarray:
278+
"""
279+
Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed
280+
using the `adapter`. The output is converted to NumPy array after inference.
281+
282+
Parameters
283+
----------
284+
conditions : dict[str, np.ndarray]
285+
Dictionary of conditioning variables as NumPy arrays.
286+
logits: bool, default=False
287+
Should the posterior model probabilities be returned in the (unconstrained) logit space?
288+
If `False`, the output lies on the unit simplex (i.e., valid probabilities) instead.
289+
**kwargs : dict
290+
Additional keyword arguments for the adapter and classification process.
291+
292+
Returns
293+
-------
294+
np.ndarray
295+
Predicted posterior model probabilities given `conditions`.
296+
"""
210297
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
211298
# at inference time, model_indices are predicted by the networks and thus ignored in conditions
212299
conditions.pop("model_indices", None)

0 commit comments

Comments
 (0)