26 commits
5fd675c  call configure_module before freeze_before_training (Nov 18, 2024)
91775f7  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 18, 2024)
9da9e7d  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 18, 2024)
faef707  Merge branch 'master' into chualan/fix-19658 (lantiga, Nov 19, 2024)
90ff8f0  remove bad fix (Nov 21, 2024)
a205c4a  second fix and test case (Nov 22, 2024)
ef35dca  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2024)
0e570a8  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 22, 2024)
56d05a3  remove print statement (Nov 22, 2024)
1c040d7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 22, 2024)
bfa0fd4  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 25, 2024)
8ba644a  change assertion order for setup() and configure_model() in test_hook… (Nov 25, 2024)
1d8ef66  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 25, 2024)
11c8be4  Merge branch 'master' into chualan/fix-19658 (chualanagit, Nov 26, 2024)
9e53990  Merge branch 'master' into chualan/fix-19658 (lantiga, Nov 26, 2024)
4567c49  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 9, 2024)
c6a77a9  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 10, 2024)
5be022e  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 11, 2024)
c20c173  Merge branch 'master' into chualan/fix-19658 (lantiga, Dec 11, 2024)
344822b  Merge branch 'master' into chualan/fix-19658 (Borda, Apr 16, 2025)
c7f02ed  Merge branch 'master' into chualan/fix-19658 (Borda, Apr 16, 2025)
3661d79  Merge branch 'master' into chualan/fix-19658 (Borda, Aug 8, 2025)
173ae26  Merge branch 'master' into chualan/fix-19658 (Borda, Sep 10, 2025)
ad3375e  update (deependujha, Oct 6, 2025)
cfe7a81  update (deependujha, Oct 6, 2025)
221602f  introduce `on_before_optimizer_setup` hook (deependujha, Oct 6, 2025)
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/trainer.py
@@ -985,9 +985,9 @@ def _run(
         log.debug(f"{self.__class__.__name__}: preparing data")
         self._data_connector.prepare_data()

-        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
+        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment

         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
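With this reordering, `configure_model` runs before the `setup` hooks, so layers created lazily in `configure_model` already exist by the time callbacks receive `setup`. A minimal sketch of the resulting order, not part of this diff (`OrderProbe` and `LazyModel` are illustrative names):

```python
import torch
from torch import nn

from lightning.pytorch import Callback, LightningModule, Trainer


class OrderProbe(Callback):
    def setup(self, trainer, pl_module, stage):
        # With the new ordering, configure_model has already run by the
        # time setup fires, so the lazily created layer is visible here.
        assert hasattr(pl_module, "backbone")


class LazyModel(LightningModule):
    def configure_model(self):
        # Create parameters lazily, e.g. for sharded/meta-device init.
        self.backbone = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.backbone(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(callbacks=[OrderProbe()], max_epochs=1, limit_train_batches=1, logger=False)
trainer.fit(LazyModel(), torch.randn(8, 32))
```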
48 changes: 48 additions & 0 deletions tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -431,3 +431,51 @@ def test_unsupported_strategies(tmp_path):
     trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback])
     with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"):
         callback.setup(trainer, model, stage=None)
+
+
+def test_finetuning_with_configure_model(tmp_path):
+    """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training is called
+    after configure_model but before training starts."""
+
+    class TrackingFinetuningCallback(BaseFinetuning):
+        def __init__(self):
+            super().__init__()
+
+        def freeze_before_training(self, pl_module):
+            assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing"
+            self.freeze(pl_module.backbone)
+
+        def finetune_function(self, pl_module, epoch, optimizer):
+            pass
+
+    class TestModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.configure_model_called_count = 0
+
+        def configure_model(self):
+            self.backbone = nn.Linear(32, 32)
+            self.classifier = nn.Linear(32, 2)
+            self.configure_model_called_count += 1
+
+        def forward(self, x):
+            x = self.backbone(x)
+            return self.classifier(x)
+
+        def training_step(self, batch, batch_idx):
+            return self.forward(batch).sum()
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+
+    model = TestModel()
+    callback = TrackingFinetuningCallback()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[callback],
+        max_epochs=1,
+        limit_train_batches=1,
+    )
+
+    trainer.fit(model, torch.randn(10, 32))
+    assert model.configure_model_called_count == 1
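For context on the failure this test guards against: `BaseFinetuning.setup` invokes `freeze_before_training`, which under the old hook order ran before `configure_model`, so freezing a lazily created submodule raised `AttributeError`. A minimal sketch, with the hypothetical `BackboneFreezer` standing in for any such callback:

```python
from lightning.pytorch.callbacks import BaseFinetuning


class BackboneFreezer(BaseFinetuning):
    def freeze_before_training(self, pl_module):
        # Before this PR, setup() (and therefore this hook) ran ahead of
        # configure_model(), so pl_module.backbone did not exist yet and
        # this line raised AttributeError.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        pass  # no gradual unfreezing needed for this sketch
```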
10 changes: 5 additions & 5 deletions tests/tests_pytorch/models/test_hooks.py
@@ -472,11 +472,11 @@ def training_step(self, batch, batch_idx):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         # DeepSpeed needs the batch size to figure out throughput logging
         *([{"name": "train_dataloader"}] if using_deepspeed else []),
-        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -571,9 +571,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -651,9 +651,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -719,9 +719,9 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "setup", "kwargs": {"stage": verb}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         *(hooks if batches else []),
         {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
@@ -746,9 +746,9 @@ def test_trainer_model_hook_system_predict(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
         {"name": "setup", "kwargs": {"stage": "predict"}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         {"name": "predict_dataloader"},
         {"name": "train", "args": (False,)},