
Commit e96c474

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 09bc52b commit e96c474

1 file changed: +42 −44 lines

tests/tests_pytorch/trainer/optimization/test_optimizers.py

Lines changed: 42 additions & 44 deletions
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import call
-from unittest.mock import patch
-
+from unittest.mock import call, patch

 import pytest
 import torch
@@ -662,44 +660,44 @@ def lr_scheduler_step(*_): ...

 @patch("torch.optim.lr_scheduler.StepLR.step")
 def test_lr_scheduler_step_across_epoch_boundaries(mocked_sched, tmp_path):

(Whitespace-only change: the 41 lines at old 665-705 were removed and re-added as new 663-703 with the text of every line unchanged in this view, so only the re-added version is shown.)

+    class StepAcrossEpochsModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def training_step(self, batch, batch_idx):
+            return {"loss": torch.tensor(0.1, requires_grad=True)}
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": "step",
+                    "frequency": 5,  # Scheduler steps every 5 iterations
+                },
+            }
+
+    model = StepAcrossEpochsModel()
+
+    # Trainer configuration for cross-epoch testing
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=7,  # More than `frequency` iterations per epoch
+        max_epochs=3,  # Test across multiple epochs
+    )
+
+    # Fit the model
+    trainer.fit(model)
+
+    # Calculate the total number of steps (iterations) and expected scheduler calls
+    total_steps = 7 * 3  # Total iterations (7 batches per epoch * 3 epochs)
+    expected_steps = total_steps // 5  # Scheduler steps every 5 iterations
+
+    # Assert that the scheduler was called the expected number of times
+    assert mocked_sched.call_count == expected_steps
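
As a sanity check on the asserted count, here is a minimal sketch of the stepping arithmetic the test relies on. It is not Lightning's internal implementation; it only assumes the semantics the test encodes, namely that interval="step" with frequency=5 steps the scheduler once per five training iterations, with the counter running continuously across epochs:

# Minimal sketch of the assumed stepping semantics (not Lightning internals):
# one scheduler step per `frequency` training iterations, counted globally.
batches_per_epoch = 7  # limit_train_batches
epochs = 3  # max_epochs
frequency = 5  # lr_scheduler "frequency"

calls = 0
global_step = 0
for _epoch in range(epochs):
    for _batch in range(batches_per_epoch):
        global_step += 1
        if global_step % frequency == 0:  # fires at steps 5, 10, 15, 20
            calls += 1

assert calls == (batches_per_epoch * epochs) // frequency == 4

Because 5 does not divide 7, the firing steps (5, 10, 15, 20) land at different offsets within their epochs, so the count only comes out to 4 if the step counter is not reset at epoch boundaries, which is exactly the behavior the test exercises.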
