1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import logging
15-
16- import pytest
17- import time
1815import re
16+ import time
1917from unittest .mock import patch
18+
19+ import pytest
2020from torch .utils .data import DataLoader
2121
2222from lightning .pytorch .demos .boring_classes import BoringModel , RandomDataset , RandomIterableDataset
@@ -132,17 +132,19 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
132132 with pytest .raises (
133133 MisconfigurationException ,
134134 match = re .escape (
135- "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136- "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137- )
135+ "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136+ "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137+ ),
138138 ):
139139 Trainer (
140140 val_check_interval = 0.5 ,
141141 check_val_every_n_epoch = None ,
142142 )
143143
144+
144145def test_time_based_val_check_interval (tmp_path ):
145146 call_count = {"count" : 0 }
147+
146148 def fake_time ():
147149 result = call_count ["count" ]
148150 call_count ["count" ] += 2
@@ -168,17 +170,20 @@ def fake_time():
168170
169171
170172@pytest .mark .parametrize (
171- "check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description" ,
173+ ( "check_val_every_n_epoch" , " val_check_interval" , " epoch_duration" , " expected_val_batches" , " description") ,
172174 [
173175 (None , "00:00:00:04" , 2 , [0 , 1 , 0 , 1 , 0 ], "val_check_interval timer only, no epoch gating" ),
174176 (1 , "00:00:00:06" , 8 , [1 , 1 , 2 , 1 , 1 ], "val_check_interval timer only, no epoch gating" ),
175177 (2 , "00:00:00:06" , 9 , [0 , 2 , 0 , 2 , 0 ], "epoch gating, timer longer than epoch" ),
176178 (2 , "00:00:00:20" , 9 , [0 , 0 , 0 , 1 , 0 ], "epoch gating, timer much longer" ),
177179 (2 , "00:00:00:03" , 9 , [0 , 3 , 0 , 3 , 0 ], "epoch gating, timer shorter than epoch" ),
178- ]
180+ ],
179181)
180- def test_time_and_epoch_gated_val_check (tmp_path , check_val_every_n_epoch , val_check_interval , epoch_duration , expected_val_batches , description ):
182+ def test_time_and_epoch_gated_val_check (
183+ tmp_path , check_val_every_n_epoch , val_check_interval , epoch_duration , expected_val_batches , description
184+ ):
181185 call_count = {"count" : 0 }
186+
182187 # Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch)
183188 def fake_time ():
184189 result = call_count ["count" ]
@@ -191,7 +196,11 @@ class TestModel(BoringModel):
191196 val_epoch_calls = 0
192197
193198 def on_train_batch_end (self , * args , ** kwargs ):
194- if isinstance (self .trainer .check_val_every_n_epoch ,int ) and self .trainer .check_val_every_n_epoch > 1 and (self .trainer .current_epoch + 1 ) % self .trainer .check_val_every_n_epoch != 0 :
199+ if (
200+ isinstance (self .trainer .check_val_every_n_epoch , int )
201+ and self .trainer .check_val_every_n_epoch > 1
202+ and (self .trainer .current_epoch + 1 ) % self .trainer .check_val_every_n_epoch != 0
203+ ):
195204 time .monotonic ()
196205
197206 def on_train_epoch_end (self , * args , ** kwargs ):
@@ -205,17 +214,17 @@ def on_validation_epoch_start(self) -> None:
205214 max_steps = max_epochs * epoch_duration
206215 limit_train_batches = epoch_duration
207216
208- trainer_kwargs = dict (
209- default_root_dir = tmp_path ,
210- logger = False ,
211- enable_checkpointing = False ,
212- max_epochs = max_epochs ,
213- max_steps = max_steps ,
214- limit_val_batches = 1 ,
215- limit_train_batches = limit_train_batches ,
216- val_check_interval = val_check_interval ,
217- check_val_every_n_epoch = check_val_every_n_epoch
218- )
217+ trainer_kwargs = {
218+ " default_root_dir" : tmp_path ,
219+ " logger" : False ,
220+ " enable_checkpointing" : False ,
221+ " max_epochs" : max_epochs ,
222+ " max_steps" : max_steps ,
223+ " limit_val_batches" : 1 ,
224+ " limit_train_batches" : limit_train_batches ,
225+ " val_check_interval" : val_check_interval ,
226+ " check_val_every_n_epoch" : check_val_every_n_epoch ,
227+ }
219228
220229 with patch ("time.monotonic" , side_effect = fake_time ):
221230 model = TestModel ()
@@ -227,4 +236,4 @@ def on_validation_epoch_start(self) -> None:
227236 f"\n FAILED: { description } "
228237 f"\n Expected validation at batches: { expected_val_batches } ,"
229238 f"\n Got: { model .val_batches , model .val_epoch_calls } \n "
230- )
239+ )
0 commit comments