12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import logging
15
-
16
- import pytest
17
- import time
18
15
import re
16
+ import time
19
17
from unittest .mock import patch
18
+
19
+ import pytest
20
20
from torch .utils .data import DataLoader
21
21
22
22
from lightning .pytorch .demos .boring_classes import BoringModel , RandomDataset , RandomIterableDataset
@@ -132,17 +132,19 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
132
132
with pytest .raises (
133
133
MisconfigurationException ,
134
134
match = re .escape (
135
- "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136
- "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137
- )
135
+ "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136
+ "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137
+ ),
138
138
):
139
139
Trainer (
140
140
val_check_interval = 0.5 ,
141
141
check_val_every_n_epoch = None ,
142
142
)
143
143
144
+
144
145
def test_time_based_val_check_interval (tmp_path ):
145
146
call_count = {"count" : 0 }
147
+
146
148
def fake_time ():
147
149
result = call_count ["count" ]
148
150
call_count ["count" ] += 2
@@ -168,17 +170,20 @@ def fake_time():
168
170
169
171
170
172
@pytest .mark .parametrize (
171
- "check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description" ,
173
+ ( "check_val_every_n_epoch" , " val_check_interval" , " epoch_duration" , " expected_val_batches" , " description") ,
172
174
[
173
175
(None , "00:00:00:04" , 2 , [0 , 1 , 0 , 1 , 0 ], "val_check_interval timer only, no epoch gating" ),
174
176
(1 , "00:00:00:06" , 8 , [1 , 1 , 2 , 1 , 1 ], "val_check_interval timer only, no epoch gating" ),
175
177
(2 , "00:00:00:06" , 9 , [0 , 2 , 0 , 2 , 0 ], "epoch gating, timer longer than epoch" ),
176
178
(2 , "00:00:00:20" , 9 , [0 , 0 , 0 , 1 , 0 ], "epoch gating, timer much longer" ),
177
179
(2 , "00:00:00:03" , 9 , [0 , 3 , 0 , 3 , 0 ], "epoch gating, timer shorter than epoch" ),
178
- ]
180
+ ],
179
181
)
180
- def test_time_and_epoch_gated_val_check (tmp_path , check_val_every_n_epoch , val_check_interval , epoch_duration , expected_val_batches , description ):
182
+ def test_time_and_epoch_gated_val_check (
183
+ tmp_path , check_val_every_n_epoch , val_check_interval , epoch_duration , expected_val_batches , description
184
+ ):
181
185
call_count = {"count" : 0 }
186
+
182
187
# Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch)
183
188
def fake_time ():
184
189
result = call_count ["count" ]
@@ -191,7 +196,11 @@ class TestModel(BoringModel):
191
196
val_epoch_calls = 0
192
197
193
198
def on_train_batch_end (self , * args , ** kwargs ):
194
- if isinstance (self .trainer .check_val_every_n_epoch ,int ) and self .trainer .check_val_every_n_epoch > 1 and (self .trainer .current_epoch + 1 ) % self .trainer .check_val_every_n_epoch != 0 :
199
+ if (
200
+ isinstance (self .trainer .check_val_every_n_epoch , int )
201
+ and self .trainer .check_val_every_n_epoch > 1
202
+ and (self .trainer .current_epoch + 1 ) % self .trainer .check_val_every_n_epoch != 0
203
+ ):
195
204
time .monotonic ()
196
205
197
206
def on_train_epoch_end (self , * args , ** kwargs ):
@@ -205,17 +214,17 @@ def on_validation_epoch_start(self) -> None:
205
214
max_steps = max_epochs * epoch_duration
206
215
limit_train_batches = epoch_duration
207
216
208
- trainer_kwargs = dict (
209
- default_root_dir = tmp_path ,
210
- logger = False ,
211
- enable_checkpointing = False ,
212
- max_epochs = max_epochs ,
213
- max_steps = max_steps ,
214
- limit_val_batches = 1 ,
215
- limit_train_batches = limit_train_batches ,
216
- val_check_interval = val_check_interval ,
217
- check_val_every_n_epoch = check_val_every_n_epoch
218
- )
217
+ trainer_kwargs = {
218
+ " default_root_dir" : tmp_path ,
219
+ " logger" : False ,
220
+ " enable_checkpointing" : False ,
221
+ " max_epochs" : max_epochs ,
222
+ " max_steps" : max_steps ,
223
+ " limit_val_batches" : 1 ,
224
+ " limit_train_batches" : limit_train_batches ,
225
+ " val_check_interval" : val_check_interval ,
226
+ " check_val_every_n_epoch" : check_val_every_n_epoch ,
227
+ }
219
228
220
229
with patch ("time.monotonic" , side_effect = fake_time ):
221
230
model = TestModel ()
@@ -227,4 +236,4 @@ def on_validation_epoch_start(self) -> None:
227
236
f"\n FAILED: { description } "
228
237
f"\n Expected validation at batches: { expected_val_batches } ,"
229
238
f"\n Got: { model .val_batches , model .val_epoch_calls } \n "
230
- )
239
+ )
0 commit comments