from absl import logging
from etils import epath
from kauldron.checkpoints import checkpointer
+from kauldron.checkpoints import partial_loader
from kauldron.evals import evaluators as evaluators_lib
from kauldron.evals import run_strategies
from kauldron.train import train_step
@@ -69,6 +70,8 @@ def continuous_eval(
  # In eval-only mode, the model weights are restored from the init_transforms
  # and not the checkpoint, so we cannot skip it.
  state = trainer.init_state(skip_transforms=not trainer.setup.eval_only)
+  # Remove optimizer state to avoid using additional memory.
+  state = state.replace(opt_state=None)
  aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}

  # If preempted, the last checkpoint might be re-computed. There could be
@@ -166,6 +169,7 @@ def _preemptable_iter_new_checkpoints(
    return

  trainer_ckpt = trainer.checkpointer
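+  # Narrow the type: the partial restore below reads `trainer_ckpt.workdir`.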
+  assert isinstance(trainer_ckpt, checkpointer.Checkpointer)
  eval_ckpt = _get_eval_ckpt(trainer_ckpt, eval_names)
  # If the eval checkpoint exists, there is an ongoing eval that was preempted
  # and we should resume the ongoing eval.
@@ -175,9 +179,11 @@ def _preemptable_iter_new_checkpoints(
    logging.info('Resume evaluation...')
    # Restore the state from the last eval checkpoint
    state = eval_ckpt.restore(state)
+    step = int(state.step)
    yield state
+    # state might have been donated, we should not access it after this point.
    # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(state.step)
+    eval_ckpt.delete(step)

  for step in trainer_ckpt.iter_new_checkpoints(
      min_interval_secs=10,
@@ -189,14 +195,27 @@ def _preemptable_iter_new_checkpoints(
          .exists()
      ),
  ):
-    state = trainer_ckpt.restore(state, step=step)
+    # TODO(epot): Rather than `PartialKauldronLoader`, should instead
+    # have some `trainer_ckpt.restore(state, partial_restore=True)`
+    # Only restore the params and step from the trainer checkpoint.
+    state = partial_loader.PartialKauldronLoader(
+        workdir=trainer_ckpt.workdir,
+        # Load everything except the optimizer state.
+        new_to_old={
+            f.name: f.name
+            for f in dataclasses.fields(state)
+            if f.name != 'opt_state'
+        },
+        step=step,
+    ).transform(state)
    assert int(state.step) == step
    # Temporarily copy the state to the eval checkpoint, to ensure that
    # it won't be deleted by the train job until the current eval is done.
    eval_ckpt.save(state, step=step)
    yield state
+    # state might have been donated, we should not access it after this point.
    # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(state.step)
+    eval_ckpt.delete(step)


def _get_eval_ckpt(