|
19 | 19 | from lightning.pytorch import LightningModule, Trainer, seed_everything |
20 | 20 | from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint |
21 | 21 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset |
22 | | -from tests_pytorch.helpers.runif import RunIf |
23 | 22 | from torch import nn |
24 | 23 | from torch.optim import SGD, Optimizer |
25 | 24 | from torch.utils.data import DataLoader |
26 | 25 |
|
| 26 | +from tests_pytorch.helpers.runif import RunIf |
| 27 | + |
27 | 28 |
|
28 | 29 | class TestBackboneFinetuningCallback(BackboneFinetuning): |
29 | 30 | def on_train_epoch_start(self, trainer, pl_module): |
@@ -282,12 +283,10 @@ def test_complex_nested_model(): |
282 | 283 | directly themselves rather than exclusively their submodules containing parameters.""" |
283 | 284 |
|
284 | 285 | model = nn.Sequential( |
285 | | - OrderedDict( |
286 | | - [ |
287 | | - ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
288 | | - ("decoder", ConvBlock(128, 10)), |
289 | | - ] |
290 | | - ) |
| 286 | + OrderedDict([ |
| 287 | + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
| 288 | + ("decoder", ConvBlock(128, 10)), |
| 289 | + ]) |
291 | 290 | ) |
292 | 291 |
|
293 | 292 | # There are 10 leaf modules or parent modules w/ parameters in the test model |
|
0 commit comments