Description & Motivation
I need to move some optimizer states to the device of the corresponding gradient of the embeddings.
I extended the optimizer to do this after super().load_state_dict, but _optimizer_to_device(optimizer, self.root_device)
moves them back from CPU to the accelerator.
There is also no way to do it in on_fit_start, which #8035 proposed for parameters. That approach does not work for optimizer state, because optimizer state is loaded after on_fit_start, while parameters are loaded before on_fit_start.
see also #3698
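To make the conflict concrete, here is a minimal sketch of the kind of optimizer extension described above. The subclass name and its target_device argument are hypothetical, not part of any Lightning or PyTorch API; it only illustrates the placement that _optimizer_to_device later undoes:

```python
import torch


class EmbeddingAwareSGD(torch.optim.SGD):
    """Hypothetical optimizer that keeps its state tensors (e.g. momentum
    buffers of a large embedding table) on a chosen device after loading."""

    def __init__(self, params, target_device="cpu", **kwargs):
        super().__init__(params, **kwargs)
        self.target_device = torch.device(target_device)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        # Move every restored state tensor to the desired device.
        # Problem: Lightning's _optimizer_to_device(optimizer, root_device)
        # runs afterwards and moves them back to the accelerator.
        for state in self.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.target_device)
```
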
Pitch
Move

```python
# hook
if self.state.fn == TrainerFn.FITTING:
    call._call_callback_hooks(self, "on_fit_start")
    call._call_lightning_module_hook(self, "on_fit_start")
```

so that it runs after

```python
# restore optimizers, etc.
log.debug(f"{self.__class__.__name__}: restoring training state")
self._checkpoint_connector.restore_training_state()
```
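With that ordering, on_fit_start would fire after the optimizer state is restored, so user code could adjust device placement there. A minimal sketch of such a hook body, as a standalone helper (the function name and keep_on_cpu parameter are hypothetical, not Lightning API):

```python
import torch


def move_optimizer_state(optimizer, keep_on_cpu):
    """Hypothetical helper for on_fit_start: after the checkpoint has been
    restored, push the state tensors of selected parameters back to CPU."""
    for param, state in optimizer.state.items():
        if param in keep_on_cpu:
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to("cpu")
```
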
Alternatives
Can't think of an alternative solution. If someone knows one, let me know.
Additional context
No response