Commit 37813b6

Add documentation to approximator

1 parent e49134a

File tree

2 files changed: +138 −0 lines changed


bayesflow/approximators/approximator.py

Lines changed: 51 additions & 0 deletions
@@ -61,6 +61,57 @@ def build_dataset(
         )
 
     def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
+        """
+        Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
+        If `dataset` is not provided, a dataset is built from the `simulator`.
+        If the model has not been built, it will be built using a batch from the dataset.
+
+        Parameters
+        ----------
+        dataset : keras.utils.PyDataset, optional
+            A dataset containing simulations for training. If provided, `simulator` must be None.
+        simulator : Simulator, optional
+            A simulator used to generate a dataset. If provided, `dataset` must be None.
+        **kwargs : dict
+            Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
+            batch_size : int or None, default='auto'
+                Number of samples per gradient update. Do not specify if `dataset` is provided as a
+                `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
+            epochs : int, default=1
+                Number of epochs to train the model.
+            verbose : {"auto", 0, 1, 2}, default="auto"
+                Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
+            callbacks : list of keras.callbacks.Callback, optional
+                List of callbacks to apply during training.
+            validation_split : float, optional
+                Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
+                or tensors).
+            validation_data : tuple or dataset, optional
+                Data for validation, overriding `validation_split`.
+            shuffle : bool, default=True
+                Whether to shuffle the training data before each epoch (ignored for dataset generators).
+            initial_epoch : int, default=0
+                Epoch at which to start training (useful for resuming training).
+            steps_per_epoch : int or None, optional
+                Number of steps (batches) before declaring an epoch finished.
+            validation_steps : int or None, optional
+                Number of validation steps per validation epoch.
+            validation_batch_size : int or None, optional
+                Number of samples per validation batch (defaults to `batch_size`).
+            validation_freq : int, default=1
+                Specifies how many training epochs to run before performing validation.
+
+        Returns
+        -------
+        keras.callbacks.History
+            A history object containing the training loss and metrics values.
+
+        Raises
+        ------
+        ValueError
+            If both `dataset` and `simulator` are provided or neither is provided.
+        """
+
         if dataset is None:
             if simulator is None:
                 raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")
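
For orientation, here is a minimal usage sketch of the `fit` contract documented above. It is not part of the commit; `approximator`, `dataset`, and `simulator` are assumed to exist already (a compiled `Approximator` subclass, a `keras.utils.PyDataset` of simulations, and a `Simulator`, respectively).

# Hedged sketch of the fit() contract documented above; `approximator`,
# `dataset`, and `simulator` are assumed to be set up elsewhere.

# Option 1: train on a ready-made dataset of simulations.
history = approximator.fit(dataset=dataset, epochs=10)

# Option 2: pass a simulator; fit() builds a dataset from it internally.
history = approximator.fit(simulator=simulator, epochs=10, batch_size=64)

# Passing both (or neither) raises a ValueError, per the Raises section.
print(history.history["loss"])  # keras.callbacks.History, as documented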

bayesflow/approximators/continuous_approximator.py

Lines changed: 87 additions & 0 deletions
@@ -111,6 +111,56 @@ def compute_metrics(
         return metrics
 
     def fit(self, *args, **kwargs):
+        """
+        Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
+        If `dataset` is not provided, a dataset is built from the `simulator`.
+        If the model has not been built, it will be built using a batch from the dataset.
+
+        Parameters
+        ----------
+        dataset : keras.utils.PyDataset, optional
+            A dataset containing simulations for training. If provided, `simulator` must be None.
+        simulator : Simulator, optional
+            A simulator used to generate a dataset. If provided, `dataset` must be None.
+        **kwargs : dict
+            Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
+            batch_size : int or None, default='auto'
+                Number of samples per gradient update. Do not specify if `dataset` is provided as a
+                `keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
+            epochs : int, default=1
+                Number of epochs to train the model.
+            verbose : {"auto", 0, 1, 2}, default="auto"
+                Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
+            callbacks : list of keras.callbacks.Callback, optional
+                List of callbacks to apply during training.
+            validation_split : float, optional
+                Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
+                or tensors).
+            validation_data : tuple or dataset, optional
+                Data for validation, overriding `validation_split`.
+            shuffle : bool, default=True
+                Whether to shuffle the training data before each epoch (ignored for dataset generators).
+            initial_epoch : int, default=0
+                Epoch at which to start training (useful for resuming training).
+            steps_per_epoch : int or None, optional
+                Number of steps (batches) before declaring an epoch finished.
+            validation_steps : int or None, optional
+                Number of validation steps per validation epoch.
+            validation_batch_size : int or None, optional
+                Number of samples per validation batch (defaults to `batch_size`).
+            validation_freq : int, default=1
+                Specifies how many training epochs to run before performing validation.
+
+        Returns
+        -------
+        keras.callbacks.History
+            A history object containing the training loss and metrics values.
+
+        Raises
+        ------
+        ValueError
+            If both `dataset` and `simulator` are provided or neither is provided.
+        """
         return super().fit(*args, **kwargs, adapter=self.adapter)
 
     @classmethod
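
Since this override only injects `self.adapter` into the base-class `fit`, a hedged end-to-end sketch may help; the simulator choice, adapter key names, and network are illustrative assumptions, not part of this commit.

import bayesflow as bf

# Assumed toy setup: a benchmark simulator plus an adapter that renames its
# outputs to the keys the approximator expects (key names are assumptions).
simulator = bf.simulators.TwoMoons()
adapter = (
    bf.Adapter()
    .convert_dtype("float64", "float32")
    .rename("parameters", "inference_variables")
    .rename("observables", "inference_conditions")
)

approximator = bf.ContinuousApproximator(
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow(),
)
approximator.compile(optimizer="adam")

# fit() builds a dataset from the simulator and forwards self.adapter;
# additional dataset-building kwargs may be required, see build_dataset().
history = approximator.fit(simulator=simulator, epochs=2, batch_size=64)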
@@ -139,6 +189,27 @@ def sample(
         split: bool = False,
         **kwargs,
     ) -> dict[str, np.ndarray]:
+        """
+        Generates samples from the approximator given input conditions. The `conditions` dictionary is preprocessed
+        using the `adapter`. Samples are converted to NumPy arrays after inference.
+
+        Parameters
+        ----------
+        num_samples : int
+            Number of samples to generate.
+        conditions : dict[str, np.ndarray]
+            Dictionary of conditioning variables as NumPy arrays.
+        split : bool, default=False
+            Whether to split the output arrays along the last axis and return one column vector per target variable.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and sampling process.
+
+        Returns
+        -------
+        dict[str, np.ndarray]
+            Dictionary containing generated samples with the same keys as `conditions`.
+        """
         conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
         # at inference time, inference_variables are estimated by the networks and thus ignored in conditions
         conditions.pop("inference_variables", None)
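
A brief, hypothetical call pattern for `sample`; the condition key, the shapes, and a previously trained `approximator` are assumptions.

import numpy as np

# Assumed: `approximator` is trained and its adapter understands "observables".
conditions = {"observables": np.random.normal(size=(3, 2)).astype("float32")}

# Draw 500 samples per conditioning dataset; results come back as NumPy arrays.
samples = approximator.sample(num_samples=500, conditions=conditions)
for name, values in samples.items():
    print(name, values.shape)  # e.g. (3, 500, dim) -- shapes assumed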
@@ -192,6 +263,22 @@ def _sample(
         )
 
     def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray:
+        """
+        Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
+        `adapter`. Log-probabilities are returned as NumPy arrays.
+
+        Parameters
+        ----------
+        data : dict[str, np.ndarray]
+            Dictionary of observed data as NumPy arrays.
+        **kwargs : dict
+            Additional keyword arguments for the adapter and log-probability computation.
+
+        Returns
+        -------
+        np.ndarray
+            Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`.
+        """
         data = self.adapter(data, strict=False, stage="inference", **kwargs)
         data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
         log_prob = self._log_prob(**data, **kwargs)
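
And a matching hypothetical call for `log_prob`, evaluating the documented density `p(inference_variables | inference_conditions, h(summary_conditions))` on held-out data; key names and shapes are again assumed.

import numpy as np

# Assumed: keys match what the trained approximator's adapter expects.
data = {
    "parameters": np.random.normal(size=(3, 2)).astype("float32"),
    "observables": np.random.normal(size=(3, 2)).astype("float32"),
}
log_density = approximator.log_prob(data)  # np.ndarray, one value per dataset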
