diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 26af335f7be93..e0de8a24b38f5 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -389,6 +389,7 @@ def __init__( self._add_instantiators() self.before_instantiate_classes() self.instantiate_classes() + self.after_instantiate_classes() if self.subcommand is not None: self._run_subcommand(self.subcommand) @@ -560,6 +561,9 @@ def instantiate_classes(self) -> None: self._add_configure_optimizers_method_to_model(self.subcommand) self.trainer = self.instantiate_trainer() + def after_instantiate_classes(self) -> None: + """Implement to run some code after instantiating the classes.""" + def instantiate_trainer(self, **kwargs: Any) -> Trainer: """Instantiates the trainer.