19 | 19 | from lightning.pytorch import LightningModule, Trainer, seed_everything
20 | 20 | from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
21 | 21 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
| 22 | +from tests_pytorch.helpers.runif import RunIf
22 | 23 | from torch import nn
23 | 24 | from torch.optim import SGD, Optimizer
24 | 25 | from torch.utils.data import DataLoader
25 | 26 |
26 | | -from tests_pytorch.helpers.runif import RunIf
27 | | -
28 | 27 |
29 | 28 | class TestBackboneFinetuningCallback(BackboneFinetuning):
30 | 29 |     def on_train_epoch_start(self, trainer, pl_module):

@@ -283,10 +282,12 @@ def test_complex_nested_model():
283 | 282 |     directly themselves rather than exclusively their submodules containing parameters."""
284 | 283 |
285 | 284 |     model = nn.Sequential(
286 | | -        OrderedDict([
287 | | -            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
288 | | -            ("decoder", ConvBlock(128, 10)),
289 | | -        ])
| 285 | +        OrderedDict(
| 286 | +            [
| 287 | +                ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
| 288 | +                ("decoder", ConvBlock(128, 10)),
| 289 | +            ]
| 290 | +        )
290 | 291 |     )
291 | 292 |
292 | 293 |     # There are 10 leaf modules or parent modules w/ parameters in the test model
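
For context, ConvBlockParam (presumably defined earlier in this test module) is a block that owns a parameter directly in addition to its child modules, which is what the docstring above refers to. A hedged sketch of such a module, with illustrative names and layer sizes (the real definition may differ):

class ConvBlockParamSketch(nn.Module):
    """Hypothetical stand-in for a module that registers an nn.Parameter on itself."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.act = nn.ReLU()
        # A parameter owned by the parent module rather than by a child module;
        # BaseFinetuning.flatten_modules has to account for modules like this.
        self.parent_param = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.act(self.conv(x)) + self.parent_param
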
@@ -346,6 +347,8 @@ def test_callbacks_restore(tmp_path):
346 | 347 |     assert len(callback._internal_optimizer_metadata) == 1
347 | 348 |
348 | 349 |     # only 2 param groups
| 350 | +    print("##########")
| 351 | +    print(callback._internal_optimizer_metadata[0])
349 | 352 |     assert len(callback._internal_optimizer_metadata[0]) == 2
350 | 353 |
351 | 354 |     # original parameters

@@ -431,3 +434,52 @@ def test_unsupported_strategies(tmp_path):
431 | 434 |     trainer = Trainer(accelerator="cpu", strategy="deepspeed", callbacks=[callback])
432 | 435 |     with pytest.raises(NotImplementedError, match="does not support running with the DeepSpeed strategy"):
433 | 436 |         callback.setup(trainer, model, stage=None)
| 437 | +
| 438 | +
| 439 | +def test_finetuning_with_configure_model(tmp_path):
| 440 | +    """Test that BaseFinetuning works correctly with configure_model by ensuring freeze_before_training
| 441 | +    is called after configure_model but before training starts."""
| 442 | +
| 443 | +    class TrackingFinetuningCallback(BaseFinetuning):
| 444 | +        def __init__(self):
| 445 | +            super().__init__()
| 446 | +
| 447 | +        def freeze_before_training(self, pl_module):
| 448 | +            assert hasattr(pl_module, "backbone"), "backbone should be configured before freezing"
| 449 | +            self.freeze(pl_module.backbone)
| 450 | +
| 451 | +        def finetune_function(self, pl_module, epoch, optimizer):
| 452 | +            pass
| 453 | +
| 454 | +    class TestModel(LightningModule):
| 455 | +        def __init__(self):
| 456 | +            super().__init__()
| 457 | +            self.configure_model_called_count = 0
| 458 | +
| 459 | +        def configure_model(self):
| 460 | +            self.backbone = nn.Linear(32, 32)
| 461 | +            self.classifier = nn.Linear(32, 2)
| 462 | +            self.configure_model_called_count += 1
| 463 | +
| 464 | +        def forward(self, x):
| 465 | +            x = self.backbone(x)
| 466 | +            return self.classifier(x)
| 467 | +
| 468 | +        def training_step(self, batch, batch_idx):
| 469 | +            return self.forward(batch).sum()
| 470 | +
| 471 | +        def configure_optimizers(self):
| 472 | +            return torch.optim.SGD(self.parameters(), lr=0.1)
| 473 | +
| 474 | +    print("start of the test")
| 475 | +    model = TestModel()
| 476 | +    callback = TrackingFinetuningCallback()
| 477 | +    trainer = Trainer(
| 478 | +        default_root_dir=tmp_path,
| 479 | +        callbacks=[callback],
| 480 | +        max_epochs=1,
| 481 | +        limit_train_batches=1,
| 482 | +    )
| 483 | +
| 484 | +    trainer.fit(model, torch.randn(10, 32))
| 485 | +    assert model.configure_model_called_count == 1
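
For reference, a user-facing BaseFinetuning subclass would typically pair the freeze exercised above with a staged unfreeze once training is underway. A minimal sketch, assuming the LightningModule builds a backbone attribute in configure_model (the unfreeze epoch and learning rate are illustrative):

class StagedFinetuning(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch: int = 1):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # By the time this hook runs, configure_model has built the backbone
        # (the behaviour asserted by the test above).
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        # Thaw the backbone later, in its own param group with a smaller learning rate.
        if epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer, lr=1e-3)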