import jax
import keras


class JAXApproximator(keras.Model):
    """
    Base class for approximators using JAX and Keras' stateless training interface.

    This class enables stateless training and evaluation steps with JAX, supporting
    JAX-compatible gradient computation and variable updates through the `StatelessScope`.

    Notes
    -----
    Subclasses must implement:
    - compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]
    - _batch_size_from_data(self, data: dict[str, any]) -> int

    A hypothetical minimal subclass is sketched after the class definition below.
    """

    # noinspection PyMethodOverriding
    def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
        """
        Compute and return a dictionary of metrics for the current batch.

        This method is expected to be implemented by each subclass to compute
        task-specific metrics using JAX arrays. It is compatible with stateless
        execution and must be differentiable under JAX's `grad` system.

        Parameters
        ----------
        *args : tuple
            Positional arguments passed to the metric computation function.
        **kwargs
            Keyword arguments passed to the metric computation function.

        Returns
        -------
        dict of str to jax.Array
            Dictionary containing named metric values as JAX arrays.
        """
        raise NotImplementedError

    def stateless_compute_metrics(
        self,
        trainable_variables: any,
        non_trainable_variables: any,
        metrics_variables: any,
        data: dict[str, any],
        stage: str = "training",
    ) -> (jax.Array, tuple):
        """
        Stateless computation of metrics required for JAX autograd.

        This method performs a stateless forward pass using the given model
        variables and returns both the loss and auxiliary information for
        further updates.

        Parameters
        ----------
        trainable_variables : Any
            Current values of the trainable weights.
        non_trainable_variables : Any
            Current values of non-trainable variables (e.g., batch norm statistics).
        metrics_variables : Any
            Current values of metric tracking variables.
        data : dict of str to any
            Input data dictionary passed to `compute_metrics`.
        stage : str, default="training"
            Whether the computation is for "training" or "validation".

        Returns
        -------
        loss : jax.Array
            Scalar loss tensor for gradient computation.
        aux : tuple
            Tuple containing:
            - metrics (dict of str to jax.Array)
            - updated non-trainable variables
            - updated metrics variables
        """
        state_mapping = []
        state_mapping.extend(zip(self.trainable_variables, trainable_variables))
        state_mapping.extend(zip(self.non_trainable_variables, non_trainable_variables))
        state_mapping.extend(zip(self.metrics_variables, metrics_variables))

        # perform a stateless call to compute_metrics under the mapped variable values
        with keras.StatelessScope(state_mapping) as scope:
            metrics = self.compute_metrics(**data, stage=stage)

        # read back any variable values the scope updated
        non_trainable_variables = [scope.get_current_value(v) for v in self.non_trainable_variables]
        metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]

        return metrics["loss"], (metrics, non_trainable_variables, metrics_variables)

    def stateless_test_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
        """
        Stateless validation step compatible with JAX.

        Parameters
        ----------
        state : tuple
            Tuple of (trainable_variables, non_trainable_variables, metrics_variables).
        data : dict of str to any
            Input data for validation.

        Returns
        -------
        metrics : dict of str to jax.Array
            Dictionary of computed evaluation metrics.
        state : tuple
            Updated state tuple after evaluation.
        """
        trainable_variables, non_trainable_variables, metrics_variables = state

        loss, aux = self.stateless_compute_metrics(
            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="validation"
        )
        metrics, non_trainable_variables, metrics_variables = aux

        # weight the loss tracker update by the batch size
        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

        state = trainable_variables, non_trainable_variables, metrics_variables
        return metrics, state

    def stateless_train_step(self, state: tuple, data: dict[str, any]) -> (dict[str, jax.Array], tuple):
        """
        Stateless training step compatible with JAX autograd and stateless optimization.

        Computes gradients and applies optimizer updates in a purely functional style.

        Parameters
        ----------
        state : tuple
            Tuple of (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables).
        data : dict of str to any
            Input data for training.

        Returns
        -------
        metrics : dict of str to jax.Array
            Dictionary of computed training metrics.
        state : tuple
            Updated state tuple after training.
        """
        trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables = state

        # differentiate w.r.t. the first argument (trainable_variables);
        # has_aux passes the auxiliary tuple through alongside the loss
        grad_fn = jax.value_and_grad(self.stateless_compute_metrics, has_aux=True)
        (loss, aux), grads = grad_fn(
            trainable_variables, non_trainable_variables, metrics_variables, data=data, stage="training"
        )
        metrics, non_trainable_variables, metrics_variables = aux

        # apply the optimizer update functionally, without mutating model state
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        metrics_variables = self._update_metrics(loss, metrics_variables, self._batch_size_from_data(data))

        state = trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables
        return metrics, state
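
    # Hypothetical manual driving loop, shown for illustration only; under the
    # JAX backend, Keras' trainer normally assembles this state tuple and calls
    # stateless_train_step itself. `approximator` and `batches` are assumed names:
    #
    #     state = (
    #         [v.value for v in approximator.trainable_variables],
    #         [v.value for v in approximator.non_trainable_variables],
    #         [v.value for v in approximator.optimizer.variables],
    #         [v.value for v in approximator.metrics_variables],
    #     )
    #     for batch in batches:
    #         metrics, state = approximator.stateless_train_step(state, batch)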

    def test_step(self, *args, **kwargs):
        """
        Alias to `stateless_test_step` for compatibility with `keras.Model`.

        Parameters
        ----------
        *args, **kwargs : Any
            Passed through to `stateless_test_step`.

        Returns
        -------
        See `stateless_test_step`.
        """
        return self.stateless_test_step(*args, **kwargs)

    def train_step(self, *args, **kwargs):
        """
        Alias to `stateless_train_step` for compatibility with `keras.Model`.

        Parameters
        ----------
        *args, **kwargs : Any
            Passed through to `stateless_train_step`.

        Returns
        -------
        See `stateless_train_step`.
        """
        return self.stateless_train_step(*args, **kwargs)

    def _update_metrics(self, loss: jax.Array, metrics_variables: any, sample_weight: any = None) -> any:
        """
        Updates metric tracking variables in a stateless, JAX-compatible way.

        This method updates the loss tracker (and any other Keras metrics)
        and returns the updated metric variable states for downstream use.

        Parameters
        ----------
        loss : jax.Array
            Scalar loss used for metric tracking.
        metrics_variables : Any
            Current metric variable states.
        sample_weight : Any, optional
            Sample weights to apply during the update.

        Returns
        -------
        metrics_variables : Any
            Updated metrics variable states.
        """
        state_mapping = list(zip(self.metrics_variables, metrics_variables))
        with keras.StatelessScope(state_mapping) as scope:
            self._loss_tracker.update_state(loss, sample_weight=sample_weight)

        # JAX is stateless, so we need to return the metrics as state to downstream functions
        metrics_variables = [scope.get_current_value(v) for v in self.metrics_variables]

        return metrics_variables

    # noinspection PyMethodOverriding
    def _batch_size_from_data(self, data: any) -> int:
        """Obtain the batch size from a batch of data.

        To properly weigh the metrics for batches of different sizes, the batch size
        of a given batch of data is required. As the data structure differs between
        approximators, each concrete approximator has to specify this method.

        Parameters
        ----------
        data :
            The data that are passed to `compute_metrics` as keyword arguments.

        Returns
        -------
        batch_size : int
            The batch size of the given data.
        """
        raise NotImplementedError(
            "Correct calculation of the metrics requires obtaining the batch size from the supplied data "
            "for proper weighting of metrics for batches with different sizes. Please implement the "
            "_batch_size_from_data method for your approximator. For a given batch of data, it should "
            "return the corresponding batch size."
        )
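

# A minimal sketch of a concrete subclass, for illustration only: the regression
# network, the "x"/"y" data keys, and the class name are assumptions, not part of
# the interface above. Only the two overridden methods are actually required.
class RegressionApproximator(JAXApproximator):
    def __init__(self, network: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.network = network

    def compute_metrics(self, x: jax.Array, y: jax.Array, stage: str = "training") -> dict[str, jax.Array]:
        # runs inside keras.StatelessScope, so the forward pass sees the
        # variable values injected by stateless_compute_metrics
        prediction = self.network(x, training=stage == "training")
        loss = keras.ops.mean(keras.ops.square(prediction - y))
        return {"loss": loss}

    def _batch_size_from_data(self, data: dict[str, any]) -> int:
        # here, the leading axis of the "x" tensor carries the batch dimension
        return keras.ops.shape(data["x"])[0]


# With the JAX backend, compiling and fitting such a subclass as usual
# (`approximator.compile(optimizer=...)`, then `approximator.fit(...)`) is
# expected to route training through the stateless steps defined above.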