Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 307c502

Browse files
Lukasz Kaiser authored and copybara-github committed
Create output dir in training.Loop if it's not there (same as Trainer). Also report training loss in Loop.
PiperOrigin-RevId: 319320458
1 parent ef0d18f commit 307c502

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

trax/fastmath/jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ def _custom_grad(f_vjp, f_original):
418418
'erf': jax_special.erf,
419419
'expit': jax_special.expit,
420420
'grad': jax.grad,
421+
'value_and_grad': jax.value_and_grad,
421422
'jit': jax.jit,
422423
'logsumexp': jax_special.logsumexp,
423424
'lt': lax.lt,

trax/fastmath/ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,26 @@ def grad(*args, **kwargs):
164164
return backend()['grad'](*args, **kwargs)
165165

166166

167+
def value_and_grad(*args, **kwargs):
  """Computes the gradient of the specified function together with the value.

  Mirrors `jax.value_and_grad`: if the active backend provides a native
  'value_and_grad', delegate to it. Otherwise emulate it on top of `grad`,
  at the cost of evaluating the wrapped function twice (once for the value,
  once for the gradient).

  Args:
    *args: positional arguments for `grad`; `args[0]` is the function to
      differentiate.
    **kwargs: keyword arguments forwarded to `grad` (e.g. `argnums`,
      `has_aux`).

  Returns:
    A function that, when called with the same arguments as `args[0]`,
    returns `(value, gradients)` — or `((value, aux), gradients)` when
    `has_aux=True`, matching JAX's convention.
  """
  if 'value_and_grad' in backend():
    return backend()['value_and_grad'](*args, **kwargs)
  grad_fn = grad(*args, **kwargs)
  fn = args[0]
  # Bug fix: the original tested `has_aux in kwargs` (i.e. whether the
  # boolean False was a key), so the has_aux branch was unreachable.
  has_aux = kwargs.get('has_aux', False)
  if not has_aux:
    def val_and_grad(*fn_args, **fn_kwargs):
      return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs)
    return val_and_grad
  def val_and_grad_aux(*fn_args, **fn_kwargs):
    # With has_aux=True, grad returns (gradients, aux) and fn returns
    # (value, aux); repackage as ((value, aux), gradients) like JAX.
    g, aux = grad_fn(*fn_args, **fn_kwargs)
    res, _ = fn(*fn_args, **fn_kwargs)
    return (res, aux), g
  return val_and_grad_aux
185+
186+
167187
def vjp(*args, **kwargs):
  """Computes the vector-Jacobian product for the specified function."""
  backend_vjp = backend()['vjp']
  return backend_vjp(*args, **kwargs)

trax/supervised/training.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
103103
self._model_in_training = tl.Serial(model, task.loss_layer)
104104
self._eval_model = model if eval_model is None else eval_model
105105
self._eval_task = eval_task
106+
self._rjust_len = max([0] + [len(name) for name in eval_task.metric_names])
107+
106108
self._output_dir = os.path.expanduser(output_dir) if output_dir else None
109+
if output_dir is not None:
110+
tf.io.gfile.makedirs(output_dir)
107111
default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
108112
self._checkpoint_at = checkpoint_at or default_fn
109113
self._eval_at = eval_at or default_fn
@@ -120,9 +124,10 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
120124
_, _ = task.optimizer.tree_init(self._model_in_training.weights)
121125

122126
self._gradients_and_state_fn = (
123-
fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
124-
argnums=1, # arg1 of pure_fn: weights
125-
has_aux=True))) # return (gradients, state)
127+
fastmath.jit(fastmath.value_and_grad(
128+
self._model_in_training.pure_fn,
129+
argnums=1, # arg1 of pure_fn: weights
130+
has_aux=True))) # return (loss, state), gradients
126131

127132
if eval_task is not None:
128133
model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
@@ -142,13 +147,23 @@ def run(self, n_steps=1):
142147
weights = self._model_in_training.weights
143148
state = self._model_in_training.state
144149
slots = self._task.optimizer.slots
150+
loss_acc, step_acc = 0.0, 0
145151
for _ in range(n_steps):
146152
self._step += 1
147-
weights, state, slots = self._run_one_step(weights, state, slots)
153+
loss, weights, state, slots = self._run_one_step(weights, state, slots)
154+
loss_acc += loss
155+
step_acc += 1
148156
if self._eval_at(self._step):
149157
self._model_in_training.weights = weights
150158
self._model_in_training.state = state
151159
self._eval_model.weights = self._model.weights
160+
# TODO(lukaszkaiser): move this to a better place with other reporting
161+
loss_name = self._task.loss_layer.name
162+
step_acc = max(1, step_acc)  # only here to avoid potential divide-by-0
163+
self._log_step('%s %s | % .8f' % (
164+
'train'.ljust(5), loss_name.rjust(self._rjust_len),
165+
loss_acc / float(step_acc)))
166+
loss_acc, step_acc = 0.0, 0
152167
self.run_evals(weights, state)
153168
if self._checkpoint_at(self._step):
154169
self.save_checkpoint(weights, state, slots)
@@ -199,11 +214,11 @@ def _run_one_step(self, weights, state, slots):
199214
opt_params = optimizer._init_opt_params # pylint: disable=protected-access
200215
opt_params.update({'learning_rate': self._task.learning_rate(step)})
201216

202-
gradients, updated_state = (
217+
(loss, updated_state), gradients = (
203218
self._gradients_and_state_fn(batch, weights, state, self.new_rng()))
204219
updated_weights, updated_slots, _ = (
205220
optimizer.tree_update(step, gradients, weights, slots, opt_params))
206-
return updated_weights, updated_state, updated_slots
221+
return loss, updated_weights, updated_state, updated_slots
207222

208223
def run_evals(self, weights=None, state=None):
209224
"""Runs and records evals for this training session.
@@ -230,10 +245,9 @@ def run_evals(self, weights=None, state=None):
230245
self._metrics_fn(batch, metrics_weights, metrics_state, rng))
231246
sums += metric_values
232247
averages = sums / n_batches
233-
rjust_len = max([0] + [len(name) for name in eval_task.metric_names])
234248
for name, average_value in zip(eval_task.metric_names, averages):
235249
self._log_step('%s %s | % .8f' % (
236-
'eval'.ljust(5), name.rjust(rjust_len), average_value))
250+
'eval'.ljust(5), name.rjust(self._rjust_len), average_value))
237251

238252
def _log_step(self, msg):
239253
"""Logs message, labeled with the current training step number."""

0 commit comments

Comments
 (0)