 from absl import logging
 from etils import epath
 from kauldron.checkpoints import checkpointer
-from kauldron.checkpoints import partial_loader
 from kauldron.evals import evaluators as evaluators_lib
 from kauldron.evals import run_strategies
 from kauldron.train import train_step
@@ -70,8 +69,6 @@ def continuous_eval(
   # In eval-only mode, the model weights are restored from the init_transforms
   # and not the checkpoint, so we cannot skip it.
   state = trainer.init_state(skip_transforms=not trainer.setup.eval_only)
-  # Remove optimizer state to avoid using additional memory.
-  state = state.replace(opt_state=None)
   aux = {eval_name: train_step.Auxiliaries() for eval_name in eval_names}
 
   # If preempted, the last checkpoint might be re-computed. There could be
@@ -169,7 +166,6 @@ def _preemptable_iter_new_checkpoints(
     return
 
   trainer_ckpt = trainer.checkpointer
-  assert isinstance(trainer_ckpt, checkpointer.Checkpointer)
   eval_ckpt = _get_eval_ckpt(trainer_ckpt, eval_names)
   # If the eval checkpoint exists, there is an ongoing eval that was preempted
   # and we should resume the ongoing eval.
@@ -179,11 +175,9 @@ def _preemptable_iter_new_checkpoints(
     logging.info('Resume evaluation...')
     # Restore the state from the last eval checkpoint
     state = eval_ckpt.restore(state)
-    step = int(state.step)
     yield state
-    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(step)
+    eval_ckpt.delete(state.step)
 
   for step in trainer_ckpt.iter_new_checkpoints(
       min_interval_secs=10,
@@ -195,27 +189,14 @@ def _preemptable_iter_new_checkpoints(
           .exists()
       ),
   ):
-    # TODO(epot): Rather than `PartialKauldronLoader`, should instead
-    # have some `trainer_ckpt.restore(state, partial_restore=True)`
-    # Only restore the params and step from the trainer checkpoint.
-    state = partial_loader.PartialKauldronLoader(
-        workdir=trainer_ckpt.workdir,
-        # Load everything except the optimizer state.
-        new_to_old={
-            f.name: f.name
-            for f in dataclasses.fields(state)
-            if f.name != 'opt_state'
-        },
-        step=step,
-    ).transform(state)
+    state = trainer_ckpt.restore(state, step=step)
     assert int(state.step) == step
     # Temporarily copy the state to the eval checkpoint, to ensure that
     # it won't be deleted by the train job until the current eval is done.
     eval_ckpt.save(state, step=step)
     yield state
-    # state might have been donated, we should not access it after this point.
     # Eval is done, remove the duplicated checkpoint
-    eval_ckpt.delete(step)
+    eval_ckpt.delete(state.step)
 
 
 def _get_eval_ckpt(
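
For context, a rough sketch of how a caller such as `continuous_eval` might consume `_preemptable_iter_new_checkpoints` after this change. The consumer below is hypothetical and only illustrates the flow; the generator's exact signature and the per-eval call are assumptions for illustration, not the actual Kauldron API.

# Hypothetical sketch; argument order and the evaluator call are assumed.
def run_continuous_eval_sketch(trainer, eval_names):
  # Build the state skeleton once, as in the diff above.
  state = trainer.init_state(skip_transforms=not trainer.setup.eval_only)
  for state in _preemptable_iter_new_checkpoints(trainer, state, eval_names):
    # Each yielded state was restored via trainer_ckpt.restore(state, step=step)
    # and duplicated into the eval checkpoint, so the train job cannot delete
    # it while this evaluation is still running.
    for eval_name in eval_names:
      run_one_eval(eval_name, state)  # hypothetical helper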