
Commit ea966d6: add testing

1 parent: 70c86dc

File tree: 1 file changed, +45 -0 lines
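For context (not part of the commit itself): the new test exercises the interaction between the learning-rate finder and Lightning's BackboneFinetuning callback. That callback expects the LightningModule to expose its feature extractor as `self.backbone`; it freezes that module when fitting starts and unfreezes it once `unfreeze_backbone_at_epoch` is reached, which is why the test model below defines an explicit backbone/head split. A minimal sketch of the callback configuration used by the test:

```python
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning

# Freeze `model.backbone` at the start of training and unfreeze it at epoch 1,
# matching the configuration in the new test below.
backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1)
```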

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@

 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -800,3 +801,47 @@ def configure_optimizers(self):
     assert len(lr_find_checkpoints) == 0, (
         f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
     )
+
+
+def test_lr_finder_with_backbone_finetuning_callback(tmp_path):
+    """Test that lr_find works correctly with BackboneFinetuning callback."""
+
+    class ModelWithBackbone(BoringModel):
+        def __init__(self):
+            super().__init__()
+            # Create a simple backbone-head architecture
+            self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
+            self.head = torch.nn.Linear(8, 2)
+            self.learning_rate = 1e-3
+
+        def forward(self, x):
+            backbone_features = self.backbone(x)
+            return self.head(backbone_features)
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+
+    model = ModelWithBackbone()
+    backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=1)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=3,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        callbacks=[backbone_finetuning],
+    )
+
+    tuner = Tuner(trainer)
+    lr_finder = tuner.lr_find(model, num_training=5)
+
+    assert lr_finder is not None
+    assert hasattr(lr_finder, "results")
+    assert len(lr_finder.results) > 0
+    trainer.fit(model)
+
+    # Check that backbone was unfrozen at the correct epoch
+    for param in model.backbone.parameters():
+        assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1"
```
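For illustration only, here is a sketch of the end-user workflow this test covers: running the LR finder with BackboneFinetuning attached, applying the suggested rate, and then fitting. `ModelWithBackbone` stands in for any LightningModule that exposes `self.backbone` and a `learning_rate` attribute; `lr_finder.suggestion()` reads the finder's recommended rate and can return None if no suggestion was found. This is not part of the commit, just the usage pattern the test validates.

```python
# Illustrative usage only; mirrors the pattern exercised by the new test.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
from lightning.pytorch.tuner import Tuner

model = ModelWithBackbone()  # any LightningModule with `self.backbone` and `self.learning_rate`
trainer = Trainer(max_epochs=3, callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=1)])

tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, num_training=5)  # LR range test runs with the callback attached

# Apply the suggested learning rate, if one was found, then train as usual.
if lr_finder is not None and lr_finder.suggestion() is not None:
    model.learning_rate = lr_finder.suggestion()
trainer.fit(model)
```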
