
Commit 98f608f

Add wall-clock val_check_interval with epoch alignment and timer reset
Parent: 5a2b678

6 files changed (+76, -10 lines)

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
 from typing import Any, Optional, Union
+import time

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor

@@ -313,6 +314,9 @@ def on_run_end(self) -> list[_OUT_DICT]:

         if self.verbose and self.trainer.is_global_zero:
             self._print_results(logged_outputs, self._stage.value)
+
+        now = time.monotonic()
+        self.trainer._last_val_time = now

         return logged_outputs

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 9 additions & 2 deletions
@@ -14,6 +14,7 @@
 import logging
 from dataclasses import dataclass
 from typing import Any, Optional, Union
+import time

 import torch
 from typing_extensions import override

@@ -283,7 +284,13 @@ def setup_data(self) -> None:
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch

-        if isinstance(trainer.val_check_interval, int):
+        # If time-based validation is enabled, disable batch-based scheduling here.
+        # Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere.
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            trainer.val_check_batch = None
+            trainer._train_start_time = time.monotonic()
+            trainer._last_val_time = trainer._train_start_time
+        elif isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(

@@ -299,7 +306,7 @@ def setup_data(self) -> None:
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
-                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                    " `Trainer(val_check_interval)` must be time based, `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
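
For readers following the control flow, the branch above selects one of three scheduling modes when the training dataloaders are set up. The sketch below is a simplified, stand-alone illustration of that precedence, not the Lightning implementation; `resolve_val_schedule`, `max_batches`, and `val_check_time_interval` are hypothetical stand-ins for the trainer attributes touched in the diff.

```python
import time
from typing import Optional, Tuple, Union


def resolve_val_schedule(
    val_check_interval: Union[int, float],
    val_check_time_interval: Optional[float],  # seconds, or None if not time-based
    max_batches: int,
) -> Tuple[Optional[Union[int, float]], Optional[float]]:
    """Return (val_check_batch, last_val_time) under the precedence shown in the diff."""
    if val_check_time_interval is not None:
        # Wall-clock mode: no batch-based trigger; the timer starts from one
        # monotonic reading so elapsed time is immune to system clock changes.
        return None, time.monotonic()
    if isinstance(val_check_interval, int):
        # Batch-count mode: validate every `val_check_interval` training batches.
        return val_check_interval, None
    # Fraction-of-epoch mode (float): roughly int(max_batches * fraction) batches.
    return max(1, int(max_batches * val_check_interval)), None


# Example: a 30-minute wall-clock interval disables batch-based checking.
print(resolve_val_schedule(1.0, 30 * 60.0, max_batches=1000))
```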

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,7 @@
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Any, Optional, Union
+import time

 import torch
 from typing_extensions import override

@@ -534,6 +535,13 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
             # and when the loop allows to stop (min_epochs/steps met)
             return True

+        interval = self.trainer._val_check_time_interval
+        if interval is not None:
+            now = time.monotonic()
+            if now - self.trainer._last_val_time >= interval:
+                # time’s up → tell Trainer to validate
+                return True
+            return False
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
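
The hook above reads a monotonic clock rather than `time.time()`, so NTP corrections or daylight-saving jumps cannot spuriously trigger or suppress a validation run. Below is a minimal sketch of the same elapsed-time check paired with the timer reset that `on_run_end` performs; `WallClockValTrigger` is an illustrative helper, not part of the commit.

```python
import time


class WallClockValTrigger:
    """Fire once every `interval_seconds` of wall-clock training time."""

    def __init__(self, interval_seconds: float) -> None:
        self.interval = interval_seconds
        self.last_val_time = time.monotonic()  # mirrors trainer._last_val_time

    def should_validate(self) -> bool:
        # Checked after each training batch completes.
        return time.monotonic() - self.last_val_time >= self.interval

    def reset(self) -> None:
        # Called when the validation loop finishes (cf. on_run_end in the diff),
        # so the next interval is measured from the end of validation.
        self.last_val_time = time.monotonic()


trigger = WallClockValTrigger(interval_seconds=15 * 60)  # validate every 15 minutes
if trigger.should_validate():
    # run the validation loop, then restart the timer
    trigger.reset()
```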

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 4 additions & 3 deletions
@@ -15,6 +15,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from typing import Any, Optional, Union
+from datetime import timedelta

 import torch.multiprocessing as mp
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler

@@ -50,7 +51,7 @@ def __init__(self, trainer: "pl.Trainer"):

     def on_trainer_init(
         self,
-        val_check_interval: Optional[Union[int, float]],
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
         reload_dataloaders_every_n_epochs: int,
         check_val_every_n_epoch: Optional[int],
     ) -> None:

@@ -63,8 +64,8 @@ def on_trainer_init(

         if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
             raise MisconfigurationException(
-                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
-                f" found {val_check_interval!r}."
+                "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
+                "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
             )

         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
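
The relaxed guard only rejects a float fraction when `check_val_every_n_epoch=None`; integers and time-based values pass through. A hedged, stand-alone rendering of that rule (using a plain `ValueError` in place of Lightning's `MisconfigurationException`, with a hypothetical function name):

```python
from datetime import timedelta
from typing import Optional, Union


def validate_interval_config(
    val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
    check_val_every_n_epoch: Optional[int],
) -> None:
    # Only the float (fraction-of-epoch) form needs epoch boundaries, so it is the
    # only form rejected when epoch-based checking is disabled.
    if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
        raise ValueError(
            "`val_check_interval` should be an integer or a time-based duration "
            "when `check_val_every_n_epoch=None`."
        )


validate_interval_config(100, None)                    # ok: every 100 batches
validate_interval_config(timedelta(minutes=30), None)  # ok: time-based
validate_interval_config(0.25, 1)                      # ok: fraction within each epoch
# validate_interval_config(0.25, None)                 # raises ValueError
```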

src/lightning/pytorch/trainer/setup.py

Lines changed: 40 additions & 2 deletions
@@ -14,6 +14,7 @@
 """Houses the methods used to set up the Trainer."""

 from typing import Optional, Union
+from datetime import timedelta

 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning

@@ -40,7 +41,7 @@ def _init_debugging_flags(
     limit_predict_batches: Optional[Union[int, float]],
     fast_dev_run: Union[int, bool],
     overfit_batches: Union[int, float],
-    val_check_interval: Optional[Union[int, float]],
+    val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
     num_sanity_val_steps: int,
 ) -> None:
     # init debugging flags

@@ -69,6 +70,7 @@ def _init_debugging_flags(
         trainer.num_sanity_val_steps = 0
         trainer.fit_loop.max_epochs = 1
         trainer.val_check_interval = 1.0
+        trainer._val_check_time_interval = None  # time not applicable in fast_dev_run
         trainer.check_val_every_n_epoch = 1
         trainer.loggers = [DummyLogger()] if trainer.loggers else []
         rank_zero_info(

@@ -82,7 +84,16 @@ def _init_debugging_flags(
         trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
         trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
         trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
-        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        # Support time-based validation intervals:
+        # If `val_check_interval` is str/dict/timedelta, parse and store seconds on the trainer
+        # for the loops to consume.
+        trainer._val_check_time_interval = None  # default
+        if isinstance(val_check_interval, (str, dict, timedelta)):
+            trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval)
+            # Keep the numeric scheduler neutral; loops should check the time-based attribute.
+            trainer.val_check_interval = 1.0
+        else:
+            trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

     if overfit_batches_enabled:
         trainer.limit_train_batches = overfit_batches

@@ -187,3 +198,30 @@ def _log_device_info(trainer: "pl.Trainer") -> None:

     if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
         rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
+
+def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float:
+    if isinstance(value, timedelta):
+        return value.total_seconds()
+    if isinstance(value, dict):
+        td = timedelta(**value)
+        return td.total_seconds()
+    if isinstance(value, str):
+        parts = value.split(":")
+        if len(parts) != 4:
+            raise MisconfigurationException(
+                f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
+            )
+        d, h, m, s = parts
+        try:
+            days = int(d)
+            hours = int(h)
+            minutes = int(m)
+            seconds = int(s)
+        except ValueError:
+            raise MisconfigurationException(
+                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
+            )
+        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        return td.total_seconds()
+    # Should not happen given the caller guards
+    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")

src/lightning/pytorch/trainer/trainer.py

Lines changed: 11 additions & 3 deletions
@@ -109,7 +109,7 @@ def __init__(
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
         overfit_batches: Union[int, float] = 0.0,
-        val_check_interval: Optional[Union[int, float]] = None,
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None,
         check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
         log_every_n_steps: Optional[int] = None,

@@ -203,12 +203,20 @@ def __init__(
                after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
                batches. An ``int`` value can only be higher than the number of training batches when
                ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-                across epochs or during iteration-based training.
+                across epochs or during iteration-based training. Additionally, accepts a time-based duration
+                as a string ``"DD:HH:MM:SS"``, a :class:`datetime.timedelta`, or a dict of keyword arguments to
+                :class:`datetime.timedelta`. When time-based, validation triggers once the elapsed wall-clock time
+                since the last validation exceeds the interval; the check occurs after the current batch
+                completes, the validation loop runs, and the timer is reset.
                Default: ``1.0``.

            check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
                validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-                to be an integer value.
+                to be an integer value. When used together with a time-based ``val_check_interval`` and
+                ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
+                before the next multiple-of-N epoch, validation runs at the start of that epoch (after the first batch)
+                and the timer resets; if it elapses during a multiple-of-N epoch, validation runs after the current batch.
+                For ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment.
                Default: ``1``.

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
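
Putting it together, a user opts into wall-clock validation by passing any of the three duration forms; the other knobs keep their existing meaning. The snippet below assumes the behavior this commit adds (and user-defined `MyModel`, `train_loader`, `val_loader`), so treat it as a usage sketch rather than documented API.

```python
from datetime import timedelta

from lightning.pytorch import Trainer

# Equivalent ways to request validation roughly every 90 minutes of wall-clock time:
trainer = Trainer(max_epochs=10, val_check_interval="00:01:30:00")
trainer = Trainer(max_epochs=10, val_check_interval=timedelta(hours=1, minutes=30))
trainer = Trainer(max_epochs=10, val_check_interval={"hours": 1, "minutes": 30})

# With check_val_every_n_epoch > 1, time-based validation is aligned to every
# N-th epoch as described in the docstring above.
trainer = Trainer(max_epochs=20, val_check_interval={"hours": 2}, check_val_every_n_epoch=2)

# trainer.fit(MyModel(), train_loader, val_loader)  # MyModel and the loaders are placeholders
```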
