@@ -17,6 +17,7 @@
 import pytest
 import torch
 from torch import optim
+from torch.utils.data import DataLoader, TensorDataset
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -28,9 +29,6 @@
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import LRSchedulerConfig
-
-from torch import optim
-from torch.utils.data import DataLoader, TensorDataset
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -674,7 +672,7 @@ def forward(self, x):
 
     def training_step(self, batch, batch_idx):
         # Add print statement to track batch index and global step
-        if hasattr(self, 'trainer'):
+        if hasattr(self, "trainer"):
             print(f"Batch idx: {batch_idx}, Global step: {self.trainer.global_step}")
         return {"loss": torch.tensor(0.1, requires_grad=True)}
 
@@ -721,6 +719,5 @@ def configure_optimizers(self):
     # Assert that the scheduler was called the expected number of times
     # Allow for a small difference due to environment or rounding discrepancies
     assert abs(mocked_sched.call_count - expected_steps) <= 1, (
-        f"Scheduler was called {mocked_sched.call_count} times, "
-        f"but expected {expected_steps} calls."
-    )
+        f"Scheduler was called {mocked_sched.call_count} times, but expected {expected_steps} calls."
+    )
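
For context on the pattern exercised by the final hunk, here is a minimal, self-contained sketch of a mock-and-count scheduler test. It is not part of the diff: `TinyModel`, the hyperparameters, and the `Trainer` flags are illustrative assumptions. The ±1 tolerance mirrors the assertion above, since some schedulers invoke `step()` once during construction, which the mock also counts.

```python
# Hypothetical sketch, not the PR's actual test. Assumes a tiny model and
# a step-interval scheduler; all names here are illustrative.
from unittest import mock

import torch
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        # "interval": "step" asks Lightning to call scheduler.step() after every optimizer step
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}


def test_scheduler_call_count_sketch():
    model = TinyModel()
    train_data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False, enable_progress_bar=False)
    # Patch the class method so every step() call on the instance is counted
    with mock.patch.object(optim.lr_scheduler.StepLR, "step") as mocked_sched:
        trainer.fit(model, train_data)
    expected_steps = trainer.global_step  # one scheduler call per optimizer step
    # Tolerance of 1 absorbs the extra step() some schedulers take at construction time
    assert abs(mocked_sched.call_count - expected_steps) <= 1, (
        f"Scheduler was called {mocked_sched.call_count} times, but expected {expected_steps} calls."
    )
```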