|
25 | 25 |
|
26 | 26 | from lightning.pytorch import Trainer, seed_everything
|
27 | 27 | from lightning.pytorch.callbacks import EarlyStopping
|
| 28 | +from lightning.pytorch.callbacks.finetuning import BackboneFinetuning |
28 | 29 | from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
|
29 | 30 | from lightning.pytorch.demos.boring_classes import BoringModel
|
30 | 31 | from lightning.pytorch.tuner.lr_finder import _LRFinder
|
@@ -800,3 +801,47 @@ def configure_optimizers(self):
|
800 | 801 | assert len(lr_find_checkpoints) == 0, (
|
801 | 802 | f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
|
802 | 803 | )
|
| 804 | + |
| 805 | + |
| 806 | +def test_lr_finder_with_backbone_finetuning_callback(tmp_path): |
| 807 | + """Test that lr_find works correctly with BackboneFinetuning callback.""" |
| 808 | + |
| 809 | + class ModelWithBackbone(BoringModel): |
| 810 | + def __init__(self): |
| 811 | + super().__init__() |
| 812 | + # Create a simple backbone-head architecture |
| 813 | + self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8)) |
| 814 | + self.head = torch.nn.Linear(8, 2) |
| 815 | + self.learning_rate = 1e-3 |
| 816 | + |
| 817 | + def forward(self, x): |
| 818 | + backbone_features = self.backbone(x) |
| 819 | + return self.head(backbone_features) |
| 820 | + |
| 821 | + def configure_optimizers(self): |
| 822 | + return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate) |
| 823 | + |
| 824 | + model = ModelWithBackbone() |
| 825 | + backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1) |
| 826 | + |
| 827 | + trainer = Trainer( |
| 828 | + default_root_dir=tmp_path, |
| 829 | + max_epochs=3, |
| 830 | + enable_checkpointing=False, |
| 831 | + enable_progress_bar=False, |
| 832 | + enable_model_summary=False, |
| 833 | + logger=False, |
| 834 | + callbacks=[backbone_finetuning], |
| 835 | + ) |
| 836 | + |
| 837 | + tuner = Tuner(trainer) |
| 838 | + lr_finder = tuner.lr_find(model, num_training=5) |
| 839 | + |
| 840 | + assert lr_finder is not None |
| 841 | + assert hasattr(lr_finder, "results") |
| 842 | + assert len(lr_finder.results) > 0 |
| 843 | + trainer.fit(model) |
| 844 | + |
| 845 | + # Check that backbone was unfrozen at the correct epoch |
| 846 | + for param in model.backbone.parameters(): |
| 847 | + assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1" |
0 commit comments