|
19 | 19 | from lightning.pytorch import LightningModule, Trainer, seed_everything
|
20 | 20 | from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
|
21 | 21 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
22 |
| -from tests_pytorch.helpers.runif import RunIf |
23 | 22 | from torch import nn
|
24 | 23 | from torch.optim import SGD, Optimizer
|
25 | 24 | from torch.utils.data import DataLoader
|
26 | 25 |
|
| 26 | +from tests_pytorch.helpers.runif import RunIf |
| 27 | + |
27 | 28 |
|
28 | 29 | class TestBackboneFinetuningCallback(BackboneFinetuning):
|
29 | 30 | def on_train_epoch_start(self, trainer, pl_module):
|
@@ -282,12 +283,10 @@ def test_complex_nested_model():
|
282 | 283 | directly themselves rather than exclusively their submodules containing parameters."""
|
283 | 284 |
|
284 | 285 | model = nn.Sequential(
|
285 |
| - OrderedDict( |
286 |
| - [ |
287 |
| - ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
288 |
| - ("decoder", ConvBlock(128, 10)), |
289 |
| - ] |
290 |
| - ) |
| 286 | + OrderedDict([ |
| 287 | + ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))), |
| 288 | + ("decoder", ConvBlock(128, 10)), |
| 289 | + ]) |
291 | 290 | )
|
292 | 291 |
|
293 | 292 | # There are 10 leaf modules or parent modules w/ parameters in the test model
|
|
0 commit comments