
restore_training_state before on_fit_start? #20338

@lampuiho


Description & Motivation

I need to move some optimizer states to the device of the corresponding gradients of the embeddings.
I extended the optimizer to do this after super().load_state_dict() (as sketched below), but _optimizer_to_device(optimizer, self.root_device) then moves them back from CPU to the accelerator.
There is also no way to do it in on_fit_start, which #8035 proposed for parameters: that approach doesn't work for optimizer variables, because optimizer state is loaded after on_fit_start, while parameters are loaded before it.

See also #3698.
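For context, here is a minimal sketch of what my optimizer extension does. The class name and the grad-device policy are my own; only super().load_state_dict() and the standard torch.optim API are real:

    import torch
    from torch.optim import Adam

    class GradDeviceAdam(Adam):
        """Sketch: after loading a checkpoint, place each state tensor on
        the device of the corresponding parameter's gradient."""

        def load_state_dict(self, state_dict):
            super().load_state_dict(state_dict)
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    state = self.state.get(p, {})
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(p.grad.device)

Lightning then calls _optimizer_to_device(optimizer, self.root_device) during restore, which overrides this placement, and on_fit_start currently fires too early to undo it.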

Pitch

Move this hook block:

        # hook
        if self.state.fn == TrainerFn.FITTING:
            call._call_callback_hooks(self, "on_fit_start")
            call._call_lightning_module_hook(self, "on_fit_start")

so that it runs after:

        # restore optimizers, etc.
        log.debug(f"{self.__class__.__name__}: restoring training state")
        self._checkpoint_connector.restore_training_state()
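With that reordering, the device fix could live in on_fit_start. A rough sketch, assuming the same grad-device policy as above (trainer.optimizers is the real Lightning attribute; everything else is illustrative):

    import torch
    from lightning.pytorch import LightningModule

    class MyModule(LightningModule):
        def on_fit_start(self):
            # Under the proposed order this runs after restore_training_state,
            # so the restored optimizer states can be relocated here.
            for optimizer in self.trainer.optimizers:
                for group in optimizer.param_groups:
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        state = optimizer.state.get(p, {})
                        for k, v in state.items():
                            if torch.is_tensor(v):
                                state[k] = v.to(p.grad.device)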

Alternatives

I can't think of an alternative solution. If someone knows one, let me know.

Additional context

No response

cc @lantiga @Borda
