diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 7cdf7888bbfe2..570b378403561 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -443,7 +443,7 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool)
             if update_plateau_schedulers ^ config.reduce_on_plateau:
                 continue
 
-            current_idx = self.batch_idx if interval == "step" else trainer.current_epoch
+            current_idx = self.total_batch_idx if interval == "step" else trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
             # the current step modulo the schedulers frequency is zero
diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
index 6b88534f3430d..66f5b5c99f9c1 100644
--- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import call
+from unittest.mock import call, patch
 
 import pytest
 import torch
 from torch import optim
+from torch.utils.data import DataLoader, TensorDataset
 
-from lightning.pytorch import Trainer
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.core.optimizer import (
     _configure_optimizers,
@@ -657,3 +658,66 @@ def lr_scheduler_step(*_): ...
     else:
         with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"):
             _init_optimizers_and_lr_schedulers(model)
+
+
+@patch("torch.optim.lr_scheduler.StepLR.step")
+def test_lr_scheduler_step_across_epoch_boundaries(mocked_sched, tmp_path):
+    class StepAcrossEpochsModel(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            return self.layer(x)
+
+        def training_step(self, batch, batch_idx):
+            # Track the batch index and global step for debugging
+            if hasattr(self, "trainer"):
+                print(f"Batch idx: {batch_idx}, Global step: {self.trainer.global_step}")
+            return {"loss": torch.tensor(0.1, requires_grad=True)}
+
+        def train_dataloader(self):
+            x = torch.randn(21, 32)
+            y = torch.randn(21, 2)
+            return DataLoader(TensorDataset(x, y), batch_size=3)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": "step",
+                    "frequency": 5,  # scheduler steps every 5 optimization steps
+                },
+            }
+
+    model = StepAcrossEpochsModel()
+
+    # Trainer configuration for cross-epoch testing
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=7,  # more than `frequency` iterations per epoch
+        max_epochs=3,  # test across multiple epochs
+    )
+
+    # Fit the model
+    trainer.fit(model)
+
+    # Debug output
+    print(f"Mocked scheduler step calls: {mocked_sched.call_count}")
+    print(f"Mocked scheduler call history: {mocked_sched.call_args_list}")
+
+    # Calculate the total number of iterations and the expected scheduler calls
+    total_steps = 7 * 3  # total iterations (7 batches per epoch * 3 epochs)
+    expected_steps = total_steps // 5  # scheduler steps every 5 iterations
+
+    print(f"Total steps: {total_steps}")
+    print(f"Expected steps: {expected_steps}")
+
+    # Assert that the scheduler was called the expected number of times; allow
+    # an off-by-one since the scheduler's constructor also calls the mocked step().
+    assert abs(mocked_sched.call_count - expected_steps) <= 1, (
+        f"Scheduler was called {mocked_sched.call_count} times, but expected {expected_steps} calls."
+    )
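
Note (illustration, not part of the diff): `self.batch_idx` restarts from zero at every epoch, while `self.total_batch_idx` keeps counting across epochs. With `interval="step"` and `frequency=5`, the old per-epoch index makes the `current_idx % frequency` check reset at each epoch boundary, so a 7-batch epoch steps the scheduler only once per epoch (3 times over 3 epochs) instead of once every 5th optimization step overall (4 times over 21 steps). The standalone sketch below simulates the two counters outside of Lightning to show the difference; the variable names are illustrative, not Lightning attributes.

# Sketch: per-epoch index vs. running index for a step-interval scheduler
# with frequency=5, 7 batches per epoch, 3 epochs. Not Lightning code.
FREQUENCY = 5
BATCHES_PER_EPOCH = 7
EPOCHS = 3

steps_with_batch_idx = 0        # old behaviour: index resets every epoch
steps_with_total_batch_idx = 0  # fixed behaviour: index keeps growing
total_batch_idx = 0

for epoch in range(EPOCHS):
    for batch_idx in range(BATCHES_PER_EPOCH):
        # Lightning adds 1 because both counters start at 0.
        if (batch_idx + 1) % FREQUENCY == 0:
            steps_with_batch_idx += 1
        if (total_batch_idx + 1) % FREQUENCY == 0:
            steps_with_total_batch_idx += 1
        total_batch_idx += 1

print(steps_with_batch_idx)        # 3 -> scheduler under-steps across epochs
print(steps_with_total_batch_idx)  # 4 -> one step per 5 optimization steps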