@@ -22,6 +22,20 @@ class ModelComparisonApproximator(Approximator):
2222 """
2323 Defines an approximator for model (simulator) comparison, where the (discrete) posterior model probabilities are
2424 learned with a classifier.
25+
26+ Parameters
27+ ----------
28+ adapter: Adapter
29+ Adapter for data processing.
30+ num_models: int
31+ Number of models (simulators) that the approximator will compare.
32+ classifier_network: keras.Model
33+ The network (e.g., an MLP) that is used for model classification.
34+ The input of the classifier network is created by concatenating `classifier_variables`
35+ and (optional) output of the summary_network.
36+ summary_network: SummaryNetwork, optional
37+ The summary network used for data summarisation (default is None).
38+ The input of the summary network is `summary_variables`.
2539 """
2640
2741 def __init__ (
@@ -161,6 +175,60 @@ def fit(
161175 simulators : Sequence [Simulator ] = None ,
162176 ** kwargs ,
163177 ):
178+ """
179+ Trains the approximator on the provided dataset, or on data generated on demand from the given (multi-model) simulator.
180+ If `dataset` is not provided, a dataset is built from the `simulator`.
181+ If `simulator` is not provided, it will be built from the list of `simulators`.
182+ If the model has not been built, it will be built using a batch from the dataset.
183+
184+ Parameters
185+ ----------
186+ dataset : keras.utils.PyDataset, optional
187+ A dataset containing simulations for training. If provided, `simulator` must be None.
188+ simulator : ModelComparisonSimulator, optional
189+ A simulator used to generate a dataset. If provided, `dataset` must be None.
190+ simulators: Sequence[Simulator], optional
191+ A list of simulators (one simulator per model). If provided, `dataset` must be None.
192+ **kwargs : dict
193+ Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
194+
195+ batch_size : int or None, default='auto'
196+ Number of samples per gradient update. Do not specify if `dataset` is provided as a
197+ `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
198+ epochs : int, default=1
199+ Number of epochs to train the model.
200+ verbose : {"auto", 0, 1, 2}, default="auto"
201+ Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
202+ callbacks : list of keras.callbacks.Callback, optional
203+ List of callbacks to apply during training.
204+ validation_split : float, optional
205+ Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
206+ or tensors).
207+ validation_data : tuple or dataset, optional
208+ Data for validation, overriding `validation_split`.
209+ shuffle : bool, default=True
210+ Whether to shuffle the training data before each epoch (ignored for dataset generators).
211+ initial_epoch : int, default=0
212+ Epoch at which to start training (useful for resuming training).
213+ steps_per_epoch : int or None, optional
214+ Number of steps (batches) before declaring an epoch finished.
215+ validation_steps : int or None, optional
216+ Number of validation steps per validation epoch.
217+ validation_batch_size : int or None, optional
218+ Number of samples per validation batch (defaults to `batch_size`).
219+ validation_freq : int, default=1
220+ Specifies how many training epochs to run before performing validation.
221+
222+ Returns
223+ -------
224+ keras.callbacks.History
225+ A history object containing the training loss and metrics values.
226+
227+ Raises
228+ ------
229+ ValueError
230+ If both `dataset` and `simulator` or `simulators` are provided or neither is provided.
231+ """
164232 if dataset is not None :
165233 if simulator is not None or simulators is not None :
166234 raise ValueError (
@@ -207,6 +275,25 @@ def predict(
207275 logits : bool = False ,
208276 ** kwargs ,
209277 ) -> np .ndarray :
278+ """
279+ Predicts posterior model probabilities given input conditions. The `conditions` dictionary is preprocessed
280+ using the `adapter`. The output is converted to a NumPy array after inference.
281+
282+ Parameters
283+ ----------
284+ conditions : dict[str, np.ndarray]
285+ Dictionary of conditioning variables as NumPy arrays.
286+ logits: bool, default=False
287+ Whether to return the posterior model probabilities in (unconstrained) logit space.
288+ If `False` (default), the output lies on the unit simplex instead.
289+ **kwargs : dict
290+ Additional keyword arguments for the adapter and classification process.
291+
292+ Returns
293+ -------
294+ np.ndarray
295+ Predicted posterior model probabilities given `conditions`.
296+ """
210297 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
211298 # at inference time, model_indices are predicted by the networks and thus ignored in conditions
212299 conditions .pop ("model_indices" , None )
0 commit comments