|
19 | 19 | from lightning.pytorch import LightningModule, Trainer, seed_everything
|
20 | 20 | from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
|
21 | 21 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
| 22 | +from tests_pytorch.helpers.runif import RunIf |
22 | 23 | from torch import nn
|
23 | 24 | from torch.optim import SGD, Optimizer
|
24 | 25 | from torch.utils.data import DataLoader
|
25 | 26 |
|
26 |
| -from tests_pytorch.helpers.runif import RunIf |
27 |
| - |
28 | 27 |
|
29 | 28 | class TestBackboneFinetuningCallback(BackboneFinetuning):
|
30 | 29 | def on_train_epoch_start(self, trainer, pl_module):
|
@@ -283,10 +282,12 @@ def test_complex_nested_model():
|
283 | 282 | directly themselves rather than exclusively their submodules containing parameters."""
|
284 | 283 |
|
285 | 284 | model = nn.Sequential(
|
286 |
| - OrderedDict([ |
287 |
| - ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
288 |
| - ("decoder", ConvBlock(128, 10)), |
289 |
| - ]) |
| 285 | + OrderedDict( |
| 286 | + [ |
| 287 | + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
| 288 | + ("decoder", ConvBlock(128, 10)), |
| 289 | + ] |
| 290 | + ) |
290 | 291 | )
|
291 | 292 |
|
292 | 293 | # There are 10 leaf modules or parent modules w/ parameters in the test model
|
@@ -346,8 +347,6 @@ def test_callbacks_restore(tmp_path):
|
346 | 347 | assert len(callback._internal_optimizer_metadata) == 1
|
347 | 348 |
|
348 | 349 | # only 2 param groups
|
349 |
| - print("##########") |
350 |
| - print(callback._internal_optimizer_metadata[0]) |
351 | 350 | assert len(callback._internal_optimizer_metadata[0]) == 2
|
352 | 351 |
|
353 | 352 | # original parameters
|
@@ -470,7 +469,6 @@ def training_step(self, batch, batch_idx):
|
470 | 469 | def configure_optimizers(self):
|
471 | 470 | return torch.optim.SGD(self.parameters(), lr=0.1)
|
472 | 471 |
|
473 |
| - print("start of the test") |
474 | 472 | model = TestModel()
|
475 | 473 | callback = TrackingFinetuningCallback()
|
476 | 474 | trainer = Trainer(
|
|
0 commit comments