1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from unittest import mock
15- from unittest .mock import call
16- from unittest .mock import patch
17-
15+ from unittest .mock import call , patch
1816
1917import pytest
2018import torch
@@ -662,44 +660,44 @@ def lr_scheduler_step(*_): ...
662660
663661@patch ("torch.optim.lr_scheduler.StepLR.step" )
664662def test_lr_scheduler_step_across_epoch_boundaries (mocked_sched , tmp_path ):
665- class StepAcrossEpochsModel (LightningModule ):
666- def __init__ (self ):
667- super ().__init__ ()
668- self .layer = torch .nn .Linear (32 , 2 )
669-
670- def forward (self , x ):
671- return self .layer (x )
672-
673- def training_step (self , batch , batch_idx ):
674- return {"loss" : torch .tensor (0.1 , requires_grad = True )}
675-
676- def configure_optimizers (self ):
677- optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
678- scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 )
679- return {
680- "optimizer" : optimizer ,
681- "lr_scheduler" : {
682- "scheduler" : scheduler ,
683- "interval" : "step" ,
684- "frequency" : 5 , # Scheduler steps every 5 iterations
685- },
686- }
687-
688- model = StepAcrossEpochsModel ()
689-
690- # Trainer configuration for cross-epoch testing
691- trainer = Trainer (
692- default_root_dir = tmp_path ,
693- limit_train_batches = 7 , # More than `frequency` iterations per epoch
694- max_epochs = 3 , # Test across multiple epochs
695- )
696-
697- # Fit the model
698- trainer .fit (model )
699-
700- # Calculate the total number of steps (iterations) and expected scheduler calls
701- total_steps = 7 * 3 # Total iterations (7 batches per epoch * 3 epochs)
702- expected_steps = total_steps // 5 # Scheduler steps every 5 iterations
703-
704- # Assert that the scheduler was called the expected number of times
705- assert mocked_sched .call_count == expected_steps
663+ class StepAcrossEpochsModel (LightningModule ):
664+ def __init__ (self ):
665+ super ().__init__ ()
666+ self .layer = torch .nn .Linear (32 , 2 )
667+
668+ def forward (self , x ):
669+ return self .layer (x )
670+
671+ def training_step (self , batch , batch_idx ):
672+ return {"loss" : torch .tensor (0.1 , requires_grad = True )}
673+
674+ def configure_optimizers (self ):
675+ optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
676+ scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 )
677+ return {
678+ "optimizer" : optimizer ,
679+ "lr_scheduler" : {
680+ "scheduler" : scheduler ,
681+ "interval" : "step" ,
682+ "frequency" : 5 , # Scheduler steps every 5 iterations
683+ },
684+ }
685+
686+ model = StepAcrossEpochsModel ()
687+
688+ # Trainer configuration for cross-epoch testing
689+ trainer = Trainer (
690+ default_root_dir = tmp_path ,
691+ limit_train_batches = 7 , # More than `frequency` iterations per epoch
692+ max_epochs = 3 , # Test across multiple epochs
693+ )
694+
695+ # Fit the model
696+ trainer .fit (model )
697+
698+ # Calculate the total number of steps (iterations) and expected scheduler calls
699+ total_steps = 7 * 3 # Total iterations (7 batches per epoch * 3 epochs)
700+ expected_steps = total_steps // 5 # Scheduler steps every 5 iterations
701+
702+ # Assert that the scheduler was called the expected number of times
703+ assert mocked_sched .call_count == expected_steps
0 commit comments