
Commit a205c4a

Alan Chu authored and committed
second fix and test case
1 parent 90ff8f0 commit a205c4a

2 files changed: +61 -10 lines changed


src/lightning/pytorch/trainer/trainer.py

Lines changed: 3 additions & 4 deletions
@@ -28,10 +28,8 @@
 from typing import Any, Dict, Generator, Iterable, List, Optional, Union
 from weakref import proxy
 
-import torch
-from torch.optim import Optimizer
-
 import lightning.pytorch as pl
+import torch
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.types import _PATH
@@ -79,6 +77,7 @@
     LRSchedulerConfig,
 )
 from lightning.pytorch.utilities.warnings import PossibleUserWarning
+from torch.optim import Optimizer
 
 log = logging.getLogger(__name__)
 
@@ -940,9 +939,9 @@ def _run(
         log.debug(f"{self.__class__.__name__}: preparing data")
         self._data_connector.prepare_data()
 
-        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
         log.debug(f"{self.__class__.__name__}: configuring model")
         call._call_configure_model(self)
+        call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
 
         # check if we should delay restoring checkpoint till later
         if not self.strategy.restore_checkpoint_after_setup:
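
With this reordering, the callback `setup` hooks run after `configure_model`, so a callback whose `setup` logic touches modules created in `configure_model` (as `BaseFinetuning` does when it freezes a backbone before training) now sees the fully built module. Below is a minimal, hypothetical sketch of the new hook order; `HookOrderProbe` and `SketchModel` are illustrations written for this note, not code from the commit:

# A minimal, hypothetical sketch (not part of the commit) showing the effect of
# the reordering: by the time a callback's setup() runs, configure_model() has
# already built the module, so its submodules can be inspected or frozen.
import torch
from torch import nn
from torch.utils.data import DataLoader

from lightning.pytorch import Callback, LightningModule, Trainer


class HookOrderProbe(Callback):
    def setup(self, trainer, pl_module, stage):
        # With this commit, the backbone already exists at this point.
        print("setup: backbone present =", hasattr(pl_module, "backbone"))


class SketchModel(LightningModule):
    def configure_model(self):
        print("configure_model")
        self.backbone = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.backbone(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=1,
        logger=False,
        enable_checkpointing=False,
        callbacks=[HookOrderProbe()],
    )
    # Expected output order: "configure_model", then "setup: backbone present = True"
    trainer.fit(SketchModel(), DataLoader(torch.randn(8, 32), batch_size=4))

Before this commit, `_call_setup_hook` ran first, so a check like `hasattr(pl_module, "backbone")` inside `setup` would report False; that is the failure mode the new test below exercises.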

tests/tests_pytorch/callbacks/test_finetuning_callback.py

Lines changed: 58 additions & 6 deletions
@@ -19,12 +19,11 @@
 from lightning.pytorch import LightningModule, Trainer, seed_everything
 from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
+from tests_pytorch.helpers.runif import RunIf
 from torch import nn
 from torch.optim import SGD, Optimizer
 from torch.utils.data import DataLoader
 
-from tests_pytorch.helpers.runif import RunIf
-
 
 class TestBackboneFinetuningCallback(BackboneFinetuning):
     def on_train_epoch_start(self, trainer, pl_module):
@@ -283,10 +282,12 @@ def test_complex_nested_model():
     directly themselves rather than exclusively their submodules containing parameters."""
 
     model = nn.Sequential(
-        OrderedDict([
-            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
-            ("decoder", ConvBlock(128, 10)),
-        ])
+        OrderedDict(
+            [
+                ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
+                ("decoder", ConvBlock(128, 10)),
+            ]
+        )
     )
 
     # There are 10 leaf modules or parent modules w/ parameters in the test model
@@ -346,6 +347,8 @@ def test_callbacks_restore(tmp_path):
     assert len(callback._internal_optimizer_metadata) == 1
 
     # only 2 param groups
+    print("##########")
+    print(callback._internal_optimizer_metadata[0])
     assert len(callback._internal_optimizer_metadata[0]) == 2
 
     # original parameters
@@ -431,3 +434,52 @@ def test_unsupported_strategies(tmp_path):
     trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback])
     with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"):
         callback.setup(trainer, model, stage=None)
+
+
+def test_finetuning_with_configure_model(tmp_path):
+    """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training
+    is called after configure_model but before training starts."""
+
+    class TrackingFinetuningCallback(BaseFinetuning):
+        def __init__(self):
+            super().__init__()
+
+        def freeze_before_training(self, pl_module):
+            assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing"
+            self.freeze(pl_module.backbone)
+
+        def finetune_function(self, pl_module, epoch, optimizer):
+            pass
+
+    class TestModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.configure_model_called_count = 0
+
+        def configure_model(self):
+            self.backbone = nn.Linear(32, 32)
+            self.classifier = nn.Linear(32, 2)
+            self.configure_model_called_count += 1
+
+        def forward(self, x):
+            x = self.backbone(x)
+            return self.classifier(x)
+
+        def training_step(self, batch, batch_idx):
+            return self.forward(batch).sum()
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=0.1)
+
+    print("start of the test")
+    model = TestModel()
+    callback = TrackingFinetuningCallback()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[callback],
+        max_epochs=1,
+        limit_train_batches=1,
+    )
+
+    trainer.fit(model, torch.randn(10, 32))
+    assert model.configure_model_called_count == 1
