
Commit 9767f6a

Add docs to backend approximator interfaces
1 parent 85d8149 commit 9767f6a


3 files changed: +301 -17 lines changed


bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 141 additions & 13 deletions
@@ -5,9 +5,40 @@
 
 
 class JAXApproximator(keras.Model):
+    """
+    Base class for approximators using JAX and Keras' stateless training interface.
+
+    This class enables stateless training and evaluation steps with JAX, supporting
+    JAX-compatible gradient computation and variable updates through the `StatelessScope`.
+
+    Notes
+    -----
+    Subclasses must implement:
+        - compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]
+        - _batch_size_from_data(self, data: dict[str, any]) -> int
+    """
+
     # noinspection PyMethodOverriding
     def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
-        # implemented by each respective architecture
+        """
+        Compute and return a dictionary of metrics for the current batch.
+
+        This method is expected to be implemented by each subclass to compute
+        task-specific metrics using JAX arrays. It is compatible with stateless
+        execution and must be differentiable under JAX's `grad` system.
+
+        Parameters
+        ----------
+        *args : tuple
+            Positional arguments passed to the metric computation function.
+        **kwargs : dict
+            Keyword arguments passed to the metric computation function.
+
+        Returns
+        -------
+        dict of str to jax.Array
+            Dictionary containing named metric values as JAX arrays.
+        """
         raise NotImplementedError
 
     def stateless_compute_metrics(
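The Notes section above spells out the subclass contract. As a rough illustration, a concrete approximator could look like the sketch below; the class name, layer, and data keys are hypothetical, and the import path is only inferred from the file location in this commit (assumes KERAS_BACKEND=jax).

import jax
import keras

# Hypothetical module path, inferred from the file path shown in this diff.
from bayesflow.approximators.backend_approximators.jax_approximator import JAXApproximator


class ToyRegressionApproximator(JAXApproximator):
    """Illustrative subclass that fulfils the documented contract."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.net = keras.layers.Dense(1)

    def compute_metrics(self, x, y, stage: str = "training") -> dict[str, jax.Array]:
        # Must return a dict containing at least "loss", because
        # stateless_compute_metrics reads metrics["loss"] for the gradient.
        pred = self.net(x)
        loss = keras.ops.mean((pred - y) ** 2)
        return {"loss": loss}

    def _batch_size_from_data(self, data: dict[str, any]) -> int:
        # Used by the base classes when weighting metric updates.
        return keras.ops.shape(data["x"])[0]

Here `x` and `y` stand in for whatever keys the data dictionary of a real approximator carries.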
@@ -19,17 +50,34 @@ def stateless_compute_metrics(
         stage: str = "training",
     ) -> (jax.Array, tuple):
         """
-        Things we do for jax:
-        1. Accept trainable variables as the first argument
-           (can be at any position as indicated by the argnum parameter
-           in autograd, but needs to be an explicit arg)
-        2. Accept, potentially modify, and return other state variables
-        3. Return just the loss tensor as the first value
-        4. Return all other values in a tuple as the second value
-
-        This ensures:
-        1. The function is stateless
-        2. The function can be differentiated with jax autograd
+        Stateless computation of metrics required for JAX autograd.
+
+        This method performs a stateless forward pass using the given model
+        variables and returns both the loss and auxiliary information for
+        further updates.
+
+        Parameters
+        ----------
+        trainable_variables : Any
+            Current values of the trainable weights.
+        non_trainable_variables : Any
+            Current values of non-trainable variables (e.g., batch norm statistics).
+        metrics_variables : Any
+            Current values of metric tracking variables.
+        data : dict of str to any
+            Input data dictionary passed to `compute_metrics`.
+        stage : str, default="training"
+            Whether the computation is for "training" or "validation".
+
+        Returns
+        -------
+        loss : jax.Array
+            Scalar loss tensor for gradient computation.
+        aux : tuple
+            Tuple containing:
+                - metrics (dict of str to jax.Array)
+                - updated non-trainable variables
+                - updated metrics variables
         """
         state_mapping = []
         state_mapping.extend(zip(self.trainable_variables, trainable_variables))
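The removed comment block described why the signature is shaped this way: trainable variables come first, the scalar loss is the first return value, and everything else travels in an auxiliary tuple, which is exactly what `jax.value_and_grad(..., has_aux=True)` expects. A standalone toy that follows the same convention (illustration only, not bayesflow code):

import jax
import jax.numpy as jnp


def stateless_loss(trainable_variables, non_trainable_variables, data):
    # trainable variables first, loss first in the return value,
    # everything else bundled into the auxiliary tuple
    w, b = trainable_variables
    pred = data["x"] @ w + b
    loss = jnp.mean((pred - data["y"]) ** 2)
    metrics = {"loss": loss}
    return loss, (metrics, non_trainable_variables)


grad_fn = jax.value_and_grad(stateless_loss, has_aux=True)
params = (jnp.ones((3, 1)), jnp.zeros((1,)))
data = {"x": jnp.ones((8, 3)), "y": jnp.zeros((8, 1))}
(loss, (metrics, _)), grads = grad_fn(params, (), data)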
@@ -48,6 +96,23 @@ def stateless_compute_metrics(
         return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)
 
     def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless validation step compatible with JAX.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, metrics_variables).
+        data : dict of str to any
+            Input data for validation.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed evaluation metrics.
+        state : tuple
+            Updated state tuple after evaluation.
+        """
         trainable_variables, non_trainable_variables, metrics_variables = state
 
         loss, aux = self.stateless_compute_metrics(
@@ -61,6 +126,25 @@ def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str,
         return metrics, state
 
     def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
+        """
+        Stateless training step compatible with JAX autograd and stateless optimization.
+
+        Computes gradients and applies optimizer updates in a purely functional style.
+
+        Parameters
+        ----------
+        state : tuple
+            Tuple of (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables).
+        data : dict of str to any
+            Input data for training.
+
+        Returns
+        -------
+        metrics : dict of str to jax.Array
+            Dictionary of computed training metrics.
+        state : tuple
+            Updated state tuple after training.
+        """
         trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state
 
         grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
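The train step wraps `stateless_compute_metrics` in `jax.value_and_grad(..., has_aux=True)` and then, per the docstring, applies the optimizer purely functionally. A self-contained sketch of that update pattern using the Keras 3 stateless optimizer API (toy loss and variables, not bayesflow code; assumes the JAX backend):

import os

os.environ["KERAS_BACKEND"] = "jax"  # the stateless path assumes the JAX backend

import jax
import jax.numpy as jnp
import keras


def loss_fn(trainable_values, data):
    w, b = trainable_values
    pred = data["x"] @ w + b
    loss = jnp.mean((pred - data["y"]) ** 2)
    return loss, {"loss": loss}


w = keras.Variable(jnp.zeros((3, 1)))
b = keras.Variable(jnp.zeros((1,)))
optimizer = keras.optimizers.SGD(learning_rate=0.1)
optimizer.build([w, b])

trainable_values = [w.value, b.value]
optimizer_values = [v.value for v in optimizer.variables]
data = {"x": jnp.ones((8, 3)), "y": jnp.ones((8, 1))}

# gradients via value_and_grad, then a functional optimizer update:
# no variable is mutated in place, new values are returned instead
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(trainable_values, data)
trainable_values, optimizer_values = optimizer.stateless_apply(
    optimizer_values, grads, trainable_values
)

In `stateless_train_step` itself the auxiliary tuple additionally carries the updated non-trainable and metrics variables, as documented for `stateless_compute_metrics` above.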
@@ -80,17 +164,61 @@ def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str,
         return metrics, state
 
     def test_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_test_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_test_step`.
+
+        Returns
+        -------
+        See `stateless_test_step`.
+        """
         return self.stateless_test_step(*args, **kwargs)
 
     def train_step(self, *args, **kwargs):
+        """
+        Alias to `stateless_train_step` for compatibility with `keras.Model`.
+
+        Parameters
+        ----------
+        *args, **kwargs : Any
+            Passed through to `stateless_train_step`.
+
+        Returns
+        -------
+        See `stateless_train_step`.
+        """
         return self.stateless_train_step(*args, **kwargs)
 
     def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
-        # update the loss progress bar, and possibly metrics variables along with it
+        """
+        Updates metric tracking variables in a stateless JAX-compatible way.
+
+        This method updates the loss tracker (and any other Keras metrics)
+        and returns updated metric variable states for downstream use.
+
+        Parameters
+        ----------
+        loss : jax.Array
+            Scalar loss used for metric tracking.
+        metrics_variables : Any
+            Current metric variable states.
+        sample_weight : Any, optional
+            Sample weights to apply during update.
+
+        Returns
+        -------
+        metrics_variables : Any
+            Updated metrics variable states.
+        """
         state_mapping = list(zip(self.metrics_variables, metrics_variables))
         with keras.StatelessScope(state_mapping) as scope:
             self._loss_tracker.update_state(loss, sample_weight=sample_weight)
 
+        # JAX is stateless, so we need to return the metrics as state in downstream functions
         metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]
 
         return metrics_variables
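`_update_metrics` relies on `keras.StatelessScope`: variable objects are mapped to their current values, the usual stateful-looking `update_state` call runs inside the scope, and the new values are read back out with `get_current_value` instead of being written into the variables. A standalone sketch of that pattern with a single `keras.metrics.Mean` tracker (illustration only, assumes the JAX backend):

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras

tracker = keras.metrics.Mean(name="loss")
metrics_variables = [v.value for v in tracker.variables]

# map each variable object to the value it should have inside the scope
state_mapping = list(zip(tracker.variables, metrics_variables))
with keras.StatelessScope(state_mapping) as scope:
    tracker.update_state(2.0)

# the variables themselves are untouched; the updated state is read from the scope
metrics_variables = [scope.get_current_value(v) for v in tracker.variables]

This is why the method returns `metrics_variables` instead of mutating anything: with JAX, the caller has to thread the updated state forward, as the added comment in the diff notes.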

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 79 additions & 2 deletions
@@ -5,18 +5,83 @@
 
 
 class TensorFlowApproximator(keras.Model):
+    """
+    Base class for approximators using TensorFlow and Keras training logic.
+
+    This class supports training and evaluation loops using TensorFlow backends.
+    Subclasses are responsible for implementing the `compute_metrics` method and
+    `_batch_size_from_data`, which extracts batch size information from data inputs.
+
+    Notes
+    -----
+    Subclasses must implement:
+        - compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]
+        - _batch_size_from_data(self, data: dict[str, any]) -> int
+    """
+
     # noinspection PyMethodOverriding
     def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
-        # implemented by each respective architecture
+        """
+        Compute and return a dictionary of metrics for the current batch.
+
+        This method is expected to be implemented by each subclass to compute task-specific
+        metrics (e.g., loss, accuracy). The arguments are dynamically filtered based on the
+        architecture's metric signature.
+
+        Parameters
+        ----------
+        *args : tuple
+            Positional arguments passed to the metric computation function.
+        **kwargs : dict
+            Keyword arguments passed to the metric computation function.
+
+        Returns
+        -------
+        dict of str to tf.Tensor
+            Dictionary containing named metric values as TensorFlow tensors.
+        """
         raise NotImplementedError
 
     def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        """
+        Performs a single validation step.
+
+        Filters relevant keyword arguments for metric computation and updates internal
+        metric trackers using the validation data.
+
+        Parameters
+        ----------
+        data : dict of str to any
+            Input dictionary containing model inputs and possibly additional information
+            such as sample_weight or mask.
+
+        Returns
+        -------
+        dict of str to tf.Tensor
+            Dictionary of computed validation metrics.
+        """
         kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
         metrics = self.compute_metrics(**kwargs)
         self._update_metrics(metrics, self._batch_size_from_data(data))
         return metrics
 
     def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
+        """
+        Performs a single training step with gradient update.
+
+        Computes gradients of the loss with respect to the trainable variables, applies
+        the update, and updates internal metric trackers.
+
+        Parameters
+        ----------
+        data : dict of str to any
+            Input dictionary containing model inputs and training targets.
+
+        Returns
+        -------
+        dict of str to tf.Tensor
+            Dictionary of computed training metrics.
+        """
         with tf.GradientTape() as tape:
             kwargs = filter_kwargs(data | {"stage": "training"}, self.compute_metrics)
             metrics = self.compute_metrics(**kwargs)
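`filter_kwargs` (whose import is not shown in this diff) narrows the merged `data | {"stage": ...}` dictionary down to the keyword arguments that the subclass's `compute_metrics` signature actually accepts. The following is not bayesflow's implementation, just a hypothetical sketch of that idea using `inspect`:

import inspect


def keep_accepted_kwargs(kwargs: dict, fn) -> dict:
    # keep only the keys that appear as parameters of fn
    accepted = inspect.signature(fn).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}


def compute_metrics(x, y, stage: str = "training") -> dict:
    return {"loss": 0.0}


data = {"x": 1.0, "y": 2.0, "sample_weight": None}
kwargs = keep_accepted_kwargs(data | {"stage": "validation"}, compute_metrics)
# -> {"x": 1.0, "y": 2.0, "stage": "validation"}; "sample_weight" is dropped

The real `filter_kwargs` may handle more cases; this only shows why extra entries such as `sample_weight` or `mask` in `data` do not break the `compute_metrics` call.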
@@ -29,7 +94,19 @@ def train_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
         self._update_metrics(metrics, self._batch_size_from_data(data))
         return metrics
 
-    def _update_metrics(self, metrics, sample_weight=None):
+    def _update_metrics(self, metrics: dict[str, any], sample_weight: tf.Tensor = None):
+        """
+        Updates internal Keras metric objects with the given values.
+
+        If a new metric name is encountered, it is added as a new `keras.metrics.Mean` instance.
+
+        Parameters
+        ----------
+        metrics : dict of str to any
+            Dictionary of computed metric values to update.
+        sample_weight : tf.Tensor, optional
+            Sample weights to apply during metric update.
+        """
         for name, value in metrics.items():
             try:
                 metric_index = self.metrics_names.index(name)
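The method body is truncated here, but the docstring describes the intended behavior: look up each metric name and register a fresh `keras.metrics.Mean` when the name is new. A standalone sketch of that bookkeeping (hedged illustration, not the verbatim bayesflow code):

import keras

trackers: dict[str, keras.metrics.Mean] = {}


def update_trackers(metrics: dict, sample_weight=None) -> None:
    for name, value in metrics.items():
        if name not in trackers:
            # unseen metric names get a fresh running-mean tracker
            trackers[name] = keras.metrics.Mean(name=name)
        trackers[name].update_state(value, sample_weight=sample_weight)


update_trackers({"loss": 0.25, "kl": 0.05})
averaged = {name: float(tracker.result()) for name, tracker in trackers.items()}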
