|
19 | 19 | from lightning.pytorch import LightningModule, Trainer, seed_everything |
20 | 20 | from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint |
21 | 21 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset |
| 22 | +from tests_pytorch.helpers.runif import RunIf |
22 | 23 | from torch import nn |
23 | 24 | from torch.optim import SGD, Optimizer |
24 | 25 | from torch.utils.data import DataLoader |
25 | 26 |
|
26 | | -from tests_pytorch.helpers.runif import RunIf |
27 | | - |
28 | 27 |
|
29 | 28 | class TestBackboneFinetuningCallback(BackboneFinetuning): |
30 | 29 | def on_train_epoch_start(self, trainer, pl_module): |
@@ -283,10 +282,12 @@ def test_complex_nested_model(): |
283 | 282 | directly themselves rather than exclusively their submodules containing parameters.""" |
284 | 283 |
|
285 | 284 | model = nn.Sequential( |
286 | | - OrderedDict([ |
287 | | - ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
288 | | - ("decoder", ConvBlock(128, 10)), |
289 | | - ]) |
| 285 | + OrderedDict( |
| 286 | + [ |
| 287 | + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
| 288 | + ("decoder", ConvBlock(128, 10)), |
| 289 | + ] |
| 290 | + ) |
290 | 291 | ) |
291 | 292 |
|
292 | 293 | # There are 10 leaf modules or parent modules w/ parameters in the test model |
@@ -346,8 +347,6 @@ def test_callbacks_restore(tmp_path): |
346 | 347 | assert len(callback._internal_optimizer_metadata) == 1 |
347 | 348 |
|
348 | 349 | # only 2 param groups |
349 | | - print("##########") |
350 | | - print(callback._internal_optimizer_metadata[0]) |
351 | 350 | assert len(callback._internal_optimizer_metadata[0]) == 2 |
352 | 351 |
|
353 | 352 | # original parameters |
@@ -470,7 +469,6 @@ def training_step(self, batch, batch_idx): |
470 | 469 | def configure_optimizers(self): |
471 | 470 | return torch.optim.SGD(self.parameters(), lr=0.1) |
472 | 471 |
|
473 | | - print("start of the test") |
474 | 472 | model = TestModel() |
475 | 473 | callback = TrackingFinetuningCallback() |
476 | 474 | trainer = Trainer( |
|
0 commit comments