Skip to content

Commit f29ee07

Browse files
Add tests for time based validation
1 parent e9b66cb commit f29ee07

File tree

1 file changed

+96
-1
lines changed

1 file changed

+96
-1
lines changed

tests/tests_pytorch/trainer/flags/test_val_check_interval.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
import logging
1515

1616
import pytest
17+
import time
18+
import re
19+
from unittest.mock import patch
1720
from torch.utils.data import DataLoader
1821

1922
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
@@ -127,9 +130,101 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
127130
"""Test that an exception is raised when `val_check_interval` is set to float with
128131
`check_val_every_n_epoch=None`"""
129132
with pytest.raises(
130-
MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
133+
MisconfigurationException,
134+
match=re.escape(
135+
"`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
136+
"datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
137+
)
131138
):
132139
Trainer(
133140
val_check_interval=0.5,
134141
check_val_every_n_epoch=None,
135142
)
143+
144+
def test_time_based_val_check_interval(tmp_path):
145+
call_count = {"count": 0}
146+
def fake_time():
147+
result = call_count["count"]
148+
call_count["count"] += 2
149+
return result
150+
151+
with patch("time.monotonic", side_effect=fake_time):
152+
trainer = Trainer(
153+
default_root_dir=tmp_path,
154+
logger=False,
155+
enable_checkpointing=False,
156+
max_epochs=1,
157+
max_steps=5, # 5 steps: simulate 10s total wall-clock time
158+
limit_val_batches=1,
159+
val_check_interval="00:00:00:02", # every 2s
160+
)
161+
model = BoringModel()
162+
trainer.fit(model)
163+
164+
# Assert 5 validations happened
165+
val_runs = trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed
166+
# The number of validation runs should be equal to the number of times we called fake_time
167+
assert val_runs == 5, f"Expected 5 validations, got {val_runs}"
168+
169+
170+
@pytest.mark.parametrize(
171+
"check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description",
172+
[
173+
(None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"),
174+
(1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"),
175+
(2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"),
176+
(2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"),
177+
(2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"),
178+
]
179+
)
180+
def test_time_and_epoch_gated_val_check(tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description):
181+
call_count = {"count": 0}
182+
# Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch)
183+
def fake_time():
184+
result = call_count["count"]
185+
call_count["count"] += 1
186+
return result
187+
188+
# Custom model to record when validation happens (on what epoch)
189+
class TestModel(BoringModel):
190+
val_batches = []
191+
val_epoch_calls = 0
192+
193+
def on_train_batch_end(self, *args, **kwargs):
194+
if isinstance(self.trainer.check_val_every_n_epoch,int) and self.trainer.check_val_every_n_epoch > 1 and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
195+
time.monotonic()
196+
197+
def on_train_epoch_end(self, *args, **kwargs):
198+
print(trainer.fit_loop.epoch_loop.val_loop.batch_progress.current.completed)
199+
self.val_batches.append(trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed)
200+
201+
def on_validation_epoch_start(self) -> None:
202+
self.val_epoch_calls += 1
203+
204+
max_epochs = 5
205+
max_steps = max_epochs * epoch_duration
206+
limit_train_batches = epoch_duration
207+
208+
trainer_kwargs = dict(
209+
default_root_dir=tmp_path,
210+
logger=False,
211+
enable_checkpointing=False,
212+
max_epochs=max_epochs,
213+
max_steps=max_steps,
214+
limit_val_batches=1,
215+
limit_train_batches=limit_train_batches,
216+
val_check_interval=val_check_interval,
217+
check_val_every_n_epoch=check_val_every_n_epoch
218+
)
219+
220+
with patch("time.monotonic", side_effect=fake_time):
221+
model = TestModel()
222+
trainer = Trainer(**trainer_kwargs)
223+
trainer.fit(model)
224+
225+
# Validate which epochs validation happened
226+
assert model.val_batches == expected_val_batches, (
227+
f"\nFAILED: {description}"
228+
f"\nExpected validation at batches: {expected_val_batches},"
229+
f"\nGot: {model.val_batches, model.val_epoch_calls}\n"
230+
)

0 commit comments

Comments
 (0)