14
14
import logging
15
15
16
16
import pytest
17
+ import time
18
+ import re
19
+ from unittest .mock import patch
17
20
from torch .utils .data import DataLoader
18
21
19
22
from lightning .pytorch .demos .boring_classes import BoringModel , RandomDataset , RandomIterableDataset
@@ -127,9 +130,101 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
127
130
"""Test that an exception is raised when `val_check_interval` is set to float with
128
131
`check_val_every_n_epoch=None`"""
129
132
with pytest .raises (
130
- MisconfigurationException , match = "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
133
+ MisconfigurationException ,
134
+ match = re .escape (
135
+ "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136
+ "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137
+ )
131
138
):
132
139
Trainer (
133
140
val_check_interval = 0.5 ,
134
141
check_val_every_n_epoch = None ,
135
142
)
143
+
144
def test_time_based_val_check_interval(tmp_path):
    """Fit a ``BoringModel`` with a string ``val_check_interval`` under a fake monotonic clock.

    ``time.monotonic`` is patched so that every call returns a value two larger than
    the previous one (0, 2, 4, ...). After fitting, the test expects the validation
    loop's cumulative completed-batch counter to be exactly 5.
    """
    fake_now = 0

    def advance_clock():
        # Hand out the current reading, then move the fake clock forward by two ticks.
        nonlocal fake_now
        reading = fake_now
        fake_now = reading + 2
        return reading

    with patch("time.monotonic", side_effect=advance_clock):
        trainer = Trainer(
            # Duration string; presumably 'DD:HH:MM:SS', i.e. a 2 second interval.
            val_check_interval="00:00:00:02",
            limit_val_batches=1,
            max_steps=5,
            max_epochs=1,
            enable_checkpointing=False,
            logger=False,
            default_root_dir=tmp_path,
        )
        trainer.fit(BoringModel())

    # With `limit_val_batches=1` each validation run should contribute one completed
    # batch, so this total is used as the number of validation runs.
    # NOTE(review): assumes sanity-check batches are not counted here -- confirm.
    val_runs = trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed
    assert val_runs == 5, f"Expected 5 validations, got { val_runs } "
168
+
169
+
170
@pytest.mark.parametrize(
    ("check_val_every_n_epoch", "val_check_interval", "epoch_duration", "expected_val_batches", "description"),
    [
        (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"),
        (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"),
        (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"),
        (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"),
        (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"),
    ],
)
def test_time_and_epoch_gated_val_check(
    tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description
):
    """Check time-based ``val_check_interval`` combined with ``check_val_every_n_epoch``.

    ``time.monotonic`` is patched with a fake clock that advances by one on every
    call. Training runs for 5 epochs of ``epoch_duration`` batches each. At the end
    of every training epoch the model records the validation loop's
    ``batch_progress.total.completed`` value, and the recorded list is compared
    against ``expected_val_batches``.

    Args:
        tmp_path: pytest temporary directory, used as ``default_root_dir``.
        check_val_every_n_epoch: forwarded to ``Trainer``; ``None`` or an int.
        val_check_interval: duration string forwarded to ``Trainer``.
        epoch_duration: number of training batches per epoch (``limit_train_batches``).
        expected_val_batches: expected value recorded at the end of each train epoch.
        description: human-readable case label, shown only in the failure message.
    """
    ticks = 0

    def fake_time():
        # Each call returns the current tick and then advances the clock by one.
        nonlocal ticks
        now = ticks
        ticks += 1
        return now

    class ValRecordingModel(BoringModel):
        """BoringModel that records validation progress at every train-epoch end."""

        def __init__(self):
            super().__init__()
            # Per-instance state (not class-level) so no list is shared between models.
            self.val_batches = []
            self.val_epoch_calls = 0

        def on_train_batch_end(self, *args, **kwargs):
            every_n = self.trainer.check_val_every_n_epoch
            epoch_is_gated = (
                isinstance(every_n, int)
                and every_n > 1
                and (self.trainer.current_epoch + 1) % every_n != 0
            )
            if epoch_is_gated:
                # Tick the fake clock once per batch on epochs skipped by epoch gating.
                # NOTE(review): presumably the trainer does not poll the clock itself
                # on these epochs, so simulated time would otherwise stand still.
                time.monotonic()

        def on_train_epoch_end(self, *args, **kwargs):
            # Use the trainer attached to this module rather than a closure variable.
            val_progress = self.trainer.fit_loop.epoch_loop.val_loop.batch_progress
            self.val_batches.append(val_progress.total.completed)

        def on_validation_epoch_start(self) -> None:
            self.val_epoch_calls += 1

    max_epochs = 5

    with patch("time.monotonic", side_effect=fake_time):
        model = ValRecordingModel()
        trainer = Trainer(
            default_root_dir=tmp_path,
            logger=False,
            enable_checkpointing=False,
            max_epochs=max_epochs,
            max_steps=max_epochs * epoch_duration,
            limit_val_batches=1,
            limit_train_batches=epoch_duration,
            val_check_interval=val_check_interval,
            check_val_every_n_epoch=check_val_every_n_epoch,
        )
        trainer.fit(model)

    # Compare the per-epoch recorded validation progress with the expectation.
    assert model.val_batches == expected_val_batches, (
        f"\n FAILED: { description } "
        f"\n Expected validation at batches: { expected_val_batches } ,"
        f"\n Got: { model.val_batches, model.val_epoch_calls } \n "
    )
0 commit comments