
Commit ad3375e ("update")

1 parent: 173ae26

File tree: 3 files changed (+16, -17 lines)


src/lightning/pytorch/callbacks/finetuning.py

Lines changed: 11 additions & 12 deletions
@@ -106,6 +106,17 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:

     @override
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # freeze the required modules before training
+        self.freeze_before_training(pl_module)
+
+        from lightning.pytorch.strategies import DeepSpeedStrategy
+
+        if isinstance(trainer.strategy, DeepSpeedStrategy):
+            raise NotImplementedError(
+                "The Finetuning callback does not support running with the DeepSpeed strategy."
+                " Choose a different strategy or disable the callback."
+            )
+
         # restore the param_groups created during the previous training.
         if self._restarting:
             named_parameters = dict(pl_module.named_parameters())
@@ -273,18 +284,6 @@ def unfreeze_and_add_param_group(
         if params:
             optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

-    @override
-    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
-        self.freeze_before_training(pl_module)
-
-        from lightning.pytorch.strategies import DeepSpeedStrategy
-
-        if isinstance(trainer.strategy, DeepSpeedStrategy):
-            raise NotImplementedError(
-                "The Finetuning callback does not support running with the DeepSpeed strategy."
-                " Choose a different strategy or disable the callback."
-            )
-
     @staticmethod
     def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]:
         output = []
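Note: the net effect of this file's change is that freezing and the DeepSpeed guard move from the `setup` hook to `on_fit_start`, i.e. they now run after `configure_model` has built the module (see the trainer reordering below). A minimal sketch of a `BaseFinetuning` subclass affected by this; the class name, the `backbone` attribute, and the unfreeze epoch are illustrative assumptions, not part of this commit:

from lightning.pytorch.callbacks import BaseFinetuning


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freeze a backbone at fit start, unfreeze it at a chosen epoch (illustrative)."""

    def __init__(self, unfreeze_at_epoch=10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # After this commit, called from `on_fit_start` rather than `setup`,
        # so `configure_model` has already materialized the module.
        self.freeze(pl_module.backbone)  # assumes the module defines `backbone`

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # At the chosen epoch, unfreeze the backbone and register its
        # parameters with the optimizer as a new param group.
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                train_bn=True,
            )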

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -985,9 +985,9 @@ def _run(
         log.debug(f"{self.__class__.__name__}: preparing data")
         self._data_connector.prepare_data()

+        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
-        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment

         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
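With `_call_setup_hook` moved ahead of `_call_configure_model`, the `setup` hooks now fire before `configure_model` during `Trainer._run`. A small sketch of how one could observe the new ordering, assuming a Lightning 2.x install with this commit applied; `BoringModel` is the public demo model, and the `order` attribute is illustrative:

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel


class HookOrderModel(BoringModel):
    """Records the two reordered hooks as they fire (illustrative)."""

    def __init__(self):
        super().__init__()
        self.order = []

    def setup(self, stage):
        self.order.append("setup")

    def configure_model(self):
        self.order.append("configure_model")


model = HookOrderModel()
trainer = pl.Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
trainer.fit(model)
assert model.order == ["setup", "configure_model"]  # was reversed before this commit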

tests/tests_pytorch/models/test_hooks.py

Lines changed: 4 additions & 4 deletions
@@ -472,11 +472,11 @@ def training_step(self, batch, batch_idx):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
-        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         # DeepSpeed needs the batch size to figure out throughput logging
         *([{"name": "train_dataloader"}] if using_deepspeed else []),
+        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -651,9 +651,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
-        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
+        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -719,9 +719,9 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
-        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "setup", "kwargs": {"stage": verb}},
+        {"name": "configure_model"},
         {"name": "zero_grad"},
         *(hooks if batches else []),
         {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
@@ -746,9 +746,9 @@ def test_trainer_model_hook_system_predict(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
-        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
         {"name": "setup", "kwargs": {"stage": "predict"}},
+        {"name": "configure_model"},
         {"name": "zero_grad"},
         {"name": "predict_dataloader"},
         {"name": "train", "args": (False,)},
