
Commit d907d3e

Author: The kauldron Authors
Avoid loading opt_state and collections in evaluators.
PiperOrigin-RevId: 671764183
1 parent 59b3fc8 commit d907d3e
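
Context for the change: an evaluator only ever reads the model parameters (and the current step), while the optimizer state can be as large as the parameters themselves (Adam, for instance, keeps two moment buffers per parameter). A minimal sketch of the idea, using a hypothetical dataclass in place of kauldron's actual train state:

    import dataclasses
    from typing import Any, Optional

    @dataclasses.dataclass(frozen=True)
    class TrainState:
      # Illustrative stand-in for kauldron's train state, not the real class.
      step: int
      params: Any
      opt_state: Optional[Any]

    def prepare_for_eval(state: TrainState) -> TrainState:
      # Eval never runs the optimizer, so dropping opt_state frees memory
      # that would otherwise sit idle for the whole evaluation.
      return dataclasses.replace(state, opt_state=None)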


kauldron/evals/eval_impl.py

Lines changed: 22 additions & 3 deletions
@@ -24,6 +24,7 @@
 from absl import logging
 from etils import epath
 from kauldron.checkpoints import checkpointer
+from kauldron.checkpoints import partial_loader
 from kauldron.evals import evaluators as evaluators_lib
 from kauldron.evals import run_strategies
 from kauldron.train import train_step
@@ -69,6 +70,8 @@ def continuous_eval(
   # In eval-only mode, the model weights are restored from the init_transforms
   # and not the checkpoint, so we cannot skip it.
   state = trainer.init_state(skip_transforms=not trainer.setup.eval_only)
+  # Remove optimizer state to avoid using additional memory.
+  state = state.replace(opt_state=None)
   aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}

   # If preempted, the last checkpoint might be re-computed. There could be
@@ -166,6 +169,7 @@ def _preemptable_iter_new_checkpoints(
     return

   trainer_ckpt = trainer.checkpointer
+  assert isinstance(trainer_ckpt, checkpointer.Checkpointer)
   eval_ckpt = _get_eval_ckpt(trainer_ckpt, eval_names)
   # If the eval checkpoint exists, there is an ongoing eval that was preempted
   # and we should resume the ongoing eval.
@@ -175,9 +179,11 @@
     logging.info('Resume evaluation...')
     # Restore the state from the last eval checkpoint
     state = eval_ckpt.restore(state)
+    step = int(state.step)
     yield state
+    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(state.step)
+    eval_ckpt.delete(step)

   for step in trainer_ckpt.iter_new_checkpoints(
       min_interval_secs=10,
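
The new `step = int(state.step)` line snapshots the step as a plain Python int before `yield state` hands the state to the eval loop, which may donate its buffers to a jitted function. A small sketch of the underlying JAX behavior (illustrative, not kauldron code):

    import jax
    import jax.numpy as jnp

    # A jitted step that donates its input buffer, as an eval step might.
    eval_step = jax.jit(lambda x: x + 1.0, donate_argnums=0)

    old = jnp.arange(4.0)
    step = int(old[0])  # Copy to a Python int before the buffer is handed off.
    new = eval_step(old)
    # On backends that honor donation (TPU/GPU), `old` now refers to a deleted
    # buffer and reading it raises an error; the Python int `step` stays valid.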
@@ -189,14 +195,27 @@
           .exists()
       ),
   ):
-    state = trainer_ckpt.restore(state, step=step)
+    # TODO(epot): Rather than `PartialKauldronLoader`, should instead
+    # have some `trainer_ckpt.restore(state, partial_restore=True)`
+    # Only restore the params and step from the trainer checkpoint.
+    state = partial_loader.PartialKauldronLoader(
+        workdir=trainer_ckpt.workdir,
+        # Load everything except the optimizer state.
+        new_to_old={
+            f.name: f.name
+            for f in dataclasses.fields(state)
+            if f.name != 'opt_state'
+        },
+        step=step,
+    ).transform(state)
     assert int(state.step) == step
     # Temporarily copy the state to the eval checkpoint, to ensure that
     # it won't be deleted by the train job until the current eval is done.
     eval_ckpt.save(state, step=step)
     yield state
+    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(state.step)
+    eval_ckpt.delete(step)


 def _get_eval_ckpt(
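
For reference, the `new_to_old` argument built in the last hunk maps every field of the state to itself, except `opt_state`. A self-contained sketch of what that comprehension produces, again with a hypothetical state class:

    import dataclasses
    from typing import Any

    @dataclasses.dataclass(frozen=True)
    class TrainState:
      step: Any
      params: Any
      opt_state: Any
      collections: Any

    state = TrainState(step=0, params={}, opt_state=None, collections={})
    new_to_old = {
        f.name: f.name
        for f in dataclasses.fields(state)
        if f.name != 'opt_state'
    }
    print(new_to_old)
    # {'step': 'step', 'params': 'params', 'collections': 'collections'}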
