@@ -111,6 +111,67 @@ def compute_metrics(
         return metrics

     def fit(self, *args, **kwargs):
114+ """
115+ Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
116+ If `dataset` is not provided, a dataset is built from the `simulator`.
117+ If the model has not been built, it will be built using a batch from the dataset.
118+
119+ Parameters
120+ ----------
121+ dataset : keras.utils.PyDataset, optional
122+ A dataset containing simulations for training. If provided, `simulator` must be None.
123+ simulator : Simulator, optional
124+ A simulator used to generate a dataset. If provided, `dataset` must be None.
125+ **kwargs : dict
126+ Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
127+ batch_size : int or None, default='auto'
128+ Number of samples per gradient update. Do not specify if `dataset` is provided as a
129+ `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
130+ epochs : int, default=1
131+ Number of epochs to train the model.
132+ verbose : {"auto", 0, 1, 2}, default="auto"
133+ Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
134+ callbacks : list of keras.callbacks.Callback, optional
135+ List of callbacks to apply during training.
136+ validation_split : float, optional
137+ Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
138+ or tensors).
139+ validation_data : tuple or dataset, optional
140+ Data for validation, overriding `validation_split`.
141+ shuffle : bool, default=True
142+ Whether to shuffle the training data before each epoch (ignored for dataset generators).
143+ initial_epoch : int, default=0
144+ Epoch at which to start training (useful for resuming training).
145+ steps_per_epoch : int or None, optional
146+ Number of steps (batches) before declaring an epoch finished.
147+ validation_steps : int or None, optional
148+ Number of validation steps per validation epoch.
149+ validation_batch_size : int or None, optional
150+ Number of samples per validation batch (defaults to `batch_size`).
151+ validation_freq : int, default=1
152+ Specifies how many training epochs to run before performing validation.
153+
154+ Returns
155+ -------
156+ keras.callbacks.History
157+ A history object containing the training loss and metrics values.
158+
159+ Raises
160+ ------
161+ ValueError
162+ If both `dataset` and `simulator` are provided or neither is provided.
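+
+        Examples
+        --------
+        A minimal usage sketch; `approximator` and `my_dataset` are placeholders
+        for a configured approximator and a `keras.utils.PyDataset` of simulated
+        training batches:
+
+        >>> history = approximator.fit(dataset=my_dataset, epochs=10)
+        >>> history.history["loss"]  # per-epoch training loss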
163+ """
         return super().fit(*args, **kwargs, adapter=self.adapter)

     @classmethod
@@ -139,6 +200,37 @@ def sample(
         split: bool = False,
         **kwargs,
     ) -> dict[str, np.ndarray]:
192+ """
193+ Generates samples from the approximator given input conditions. The `conditions` dictionary is preprocessed
194+ using the `adapter`. Samples are converted to NumPy arrays after inference.
195+
196+ Parameters
197+ ----------
198+ num_samples : int
199+ Number of samples to generate.
200+ conditions : dict[str, np.ndarray]
201+ Dictionary of conditioning variables as NumPy arrays.
202+ split : bool, default=False
203+ Whether to split the output arrays along the last axis and return one column vector per target variable
204+ samples.
205+ **kwargs : dict
206+ Additional keyword arguments for the adapter and sampling process.
207+
208+ Returns
209+ -------
210+ dict[str, np.ndarray]
211+ Dictionary containing generated samples with the same keys as `conditions`.
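+
+        Examples
+        --------
+        An illustrative call; `observed_data` and the key `"x"` are stand-ins
+        for whatever conditioning variables the adapter expects:
+
+        >>> samples = approximator.sample(
+        ...     num_samples=1000,
+        ...     conditions={"x": observed_data},
+        ... )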
212+ """
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
         # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
         conditions.pop("inference_variables", None)
@@ -192,6 +284,29 @@ def _sample(
         )

     def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
266+ """
267+ Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
268+ `adapter`. Log-probabilities are returned as NumPy arrays.
269+
270+ Parameters
271+ ----------
272+ data : dict[str, np.ndarray]
273+ Dictionary of observed data as NumPy arrays.
274+ **kwargs : dict
275+ Additional keyword arguments for the adapter and log-probability computation.
276+
277+ Returns
278+ -------
279+ np.ndarray
280+ Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
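+
+        Examples
+        --------
+        A hypothetical call; the keys `"theta"` and `"x"` are stand-ins for the
+        variable names the adapter expects:
+
+        >>> log_density = approximator.log_prob(data={"theta": theta_draws, "x": observed_data})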
281+ """
         data = self.adapter(data, strict=False, stage="inference", **kwargs)
         data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
         log_prob = self._log_prob(**data, **kwargs)