
Commit 3bcf564

Author: The kauldron Authors

Avoid loading opt_state and collections in evaluators.

PiperOrigin-RevId: 672089006

1 parent: d907d3e

1 file changed (+3 −22 lines)

kauldron/evals/eval_impl.py

Lines changed: 3 additions & 22 deletions
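In short, the commit replaces a hand-built partial restore (which copied every state field except opt_state) with a plain checkpoint restore. Condensed from the diff below, reindented for readability:

# Before: restore every field except `opt_state`, by mapping each
# surviving field onto itself via PartialKauldronLoader.
state = partial_loader.PartialKauldronLoader(
    workdir=trainer_ckpt.workdir,
    new_to_old={
        f.name: f.name
        for f in dataclasses.fields(state)
        if f.name != 'opt_state'
    },
    step=step,
).transform(state)

# After: a single restore call against the state template.
state = trainer_ckpt.restore(state, step=step)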
@@ -24,7 +24,6 @@
 from absl import logging
 from etils import epath
 from kauldron.checkpoints import checkpointer
-from kauldron.checkpoints import partial_loader
 from kauldron.evals import evaluators as evaluators_lib
 from kauldron.evals import run_strategies
 from kauldron.train import train_step
@@ -70,8 +69,6 @@ def continuous_eval(
   # In eval-only mode, the model weights are restored from the init_transforms
   # and not the checkpoint, so we cannot skip it.
   state = trainer.init_state(skip_transforms=not trainer.setup.eval_only)
-  # Remove optimizer state to avoid using additional memory.
-  state = state.replace(opt_state=None)
   aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}

   # If preempted, the last checkpoint might be re-computed. There could be
@@ -169,7 +166,6 @@ def _preemptable_iter_new_checkpoints(
     return

   trainer_ckpt = trainer.checkpointer
-  assert isinstance(trainer_ckpt, checkpointer.Checkpointer)
   eval_ckpt = _get_eval_ckpt(trainer_ckpt, eval_names)
   # If the eval checkpoint exists, there is an ongoing eval that was preempted
   # and we should resume the ongoing eval.
@@ -179,11 +175,9 @@ def _preemptable_iter_new_checkpoints(
     logging.info('Resume evaluation...')
     # Restore the state from the last eval checkpoint
     state = eval_ckpt.restore(state)
-    step = int(state.step)
     yield state
-    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(step)
+    eval_ckpt.delete(state.step)

   for step in trainer_ckpt.iter_new_checkpoints(
       min_interval_secs=10,
@@ -195,27 +189,14 @@ def _preemptable_iter_new_checkpoints(
           .exists()
       ),
   ):
-    # TODO(epot): Rather than `PartialKauldronLoader`, should instead
-    # have some `trainer_ckpt.restore(state, partial_restore=True)`
-    # Only restore the params and step from the trainer checkpoint.
-    state = partial_loader.PartialKauldronLoader(
-        workdir=trainer_ckpt.workdir,
-        # Load everything except the optimizer state.
-        new_to_old={
-            f.name: f.name
-            for f in dataclasses.fields(state)
-            if f.name != 'opt_state'
-        },
-        step=step,
-    ).transform(state)
+    state = trainer_ckpt.restore(state, step=step)
     assert int(state.step) == step
     # Temporarily copy the state to the eval checkpoint, to ensure that
     # it won't be deleted by the train job until the current eval is done.
     eval_ckpt.save(state, step=step)
     yield state
-    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(step)
+    eval_ckpt.delete(state.step)


 def _get_eval_ckpt(
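For context, the surrounding loop implements a pin-while-evaluating pattern: each new train checkpoint is copied under a dedicated eval checkpoint before the eval runs, so the train job cannot garbage-collect that step mid-eval. A minimal sketch, with a hypothetical helper name (iter_states_for_eval) and simplified arguments; the trainer_ckpt / eval_ckpt calls mirror the diff above:

def iter_states_for_eval(trainer_ckpt, eval_ckpt, state):
  # Hypothetical condensation of _preemptable_iter_new_checkpoints.
  for step in trainer_ckpt.iter_new_checkpoints(min_interval_secs=10):
    state = trainer_ckpt.restore(state, step=step)
    # Pin a copy so the train job keeps this step alive during the eval.
    eval_ckpt.save(state, step=step)
    yield state  # The caller runs the evaluation on this state.
    # Eval done: drop the pinned copy. After this commit the step is read
    # from `state.step` directly instead of a local snapshot.
    eval_ckpt.delete(state.step)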
