
Commit 1fc077b

Authored by Sohaib-Ahmed21, with pre-commit-ci[bot], SkafteNicki, and Borda
Time based validation support (#21071)
* Add wall-clock val_check_interval with epoch alignment and timer reset
* Adjust checkpointing frequency when time-based validation is active
* Add tests for time based validation
* Update docs for time based validation through val_check_interval
* changelog
* doctest
* Parametrize val_check_interval test to include different types

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 3c81316 commit 1fc077b
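
For orientation, a minimal usage sketch of what this commit enables (the commented-out `model` / `datamodule` names are placeholders, not part of the change):

    from datetime import timedelta

    from lightning.pytorch import Trainer

    # all three spellings request validation every 15 minutes of wall-clock time
    trainer = Trainer(max_epochs=3, val_check_interval="00:00:15:00")
    trainer = Trainer(max_epochs=3, val_check_interval={"minutes": 15})
    trainer = Trainer(max_epochs=3, val_check_interval=timedelta(minutes=15))

    # trainer.fit(model, datamodule=datamodule)  # hypothetical LightningModule / DataModule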

File tree: 11 files changed (+274, -14 lines)


docs/source-pytorch/advanced/speed.rst

Lines changed: 12 additions & 1 deletion
@@ -297,7 +297,8 @@ Validation Within Training Epoch

 For large datasets, it's often desirable to check validation multiple times within a training epoch.
 Pass in a float to check that often within one training epoch. Pass in an int ``K`` to check every ``K`` training batch.
-Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
+Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. Alternatively, pass a string ("DD:HH:MM:SS"),
+a dict of ``datetime.timedelta`` kwargs, or a ``datetime.timedelta`` to check validation after a given amount of wall-clock time.

 .. testcode::

@@ -310,6 +311,16 @@ Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
     # check every 100 train batches (ie: for IterableDatasets or fixed frequency)
     trainer = Trainer(val_check_interval=100)

+    # check validation every 15 minutes of wall-clock time
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # alternatively, pass a dict of timedelta kwargs
+    trainer = Trainer(val_check_interval={"minutes": 1})
+
+    # or use a timedelta object directly
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
 Learn more in our :ref:`trainer_flags` guide.

docs/source-pytorch/common/trainer.rst

Lines changed: 28 additions & 1 deletion
@@ -991,11 +991,23 @@ val_check_interval
     :muted:

 How often within one training epoch to check the validation set.
-Can specify as float or int.
+Can specify as float, int, or a time-based duration.

 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
 - pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
   batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training.
+- pass a ``string`` duration in the format "DD:HH:MM:SS", a ``datetime.timedelta`` object, or a ``dictionary`` of keyword arguments that can be passed
+  to ``datetime.timedelta`` for time-based validation. When using a time-based duration, validation will trigger once the elapsed wall-clock time
+  since the last validation exceeds the interval. The validation check occurs after the current batch completes, the validation loop runs, and
+  the timer resets.
+
+**Time-based validation behavior with check_val_every_n_epoch:** When used together with ``val_check_interval`` (time-based) and
+``check_val_every_n_epoch > 1``, validation is aligned to epoch multiples:
+
+- If the time-based interval elapses **before** the next multiple-N epoch, validation runs at the start of that epoch (after the first batch),
+  and the timer resets.
+- If the interval elapses **during** a multiple-N epoch, validation runs after the current batch.
+- For cases where ``check_val_every_n_epoch=None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment.

 .. testcode::

@@ -1013,10 +1025,25 @@ Can specify as float or int.
     # (ie: production cases with streaming data)
     trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)

+    # check validation every 15 minutes of wall-clock time using a string-based approach
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # check validation every 15 minutes of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"minutes": 15})
+
+    # check validation every 1 hour of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"hours": 1})
+
+    # check validation every 1 hour of wall-clock time using a datetime.timedelta object
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
+
 .. code-block:: python

     # Here is the computation to estimate the total number of batches seen within an epoch.
+    # This logic applies when `val_check_interval` is specified as an integer or a float.

     # Find the total number of train batches
     total_train_batches = total_train_samples // (train_batch_size * world_size)
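
A short sketch of the alignment rule documented above, combining a time-based interval with an epoch multiple (an illustrative configuration, not taken from the diff):

    from lightning.pytorch import Trainer

    # validate once at least 30 minutes have passed, but only on every 2nd epoch;
    # if the 30 minutes elapse during epoch 1, validation waits for the start of epoch 2
    trainer = Trainer(val_check_interval={"minutes": 30}, check_val_every_n_epoch=2)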

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))


+- Added time based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
+
+
 ### Changed

 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 6 additions & 0 deletions
@@ -137,6 +137,8 @@ class ModelCheckpoint(Checkpoint):
             If ``True``, checkpoints are saved at the end of every training epoch.
             If ``False``, checkpoints are saved at the end of validation.
             If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after
+            validation.
             If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
             every training epoch. If there are no validation batches of data, checkpointing will occur at the
             end of the training epoch. If there is a non-default number of validation runs per training epoch
@@ -517,6 +519,10 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end

+        # time-based validation: always defer saving to validation end
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            return False
+
         # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
         # so let's not enforce saving at every training epoch end
         if trainer.check_val_every_n_epoch != 1:
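
Note that the explicit flag still wins: `_save_on_train_epoch_end` is returned before the new time-based branch, so a user who wants epoch-end saving can request it directly. A hedged sketch using existing ModelCheckpoint arguments:

    from lightning.pytorch.callbacks import ModelCheckpoint

    # an explicit True bypasses the new "defer to validation end" default for time-based intervals
    ckpt = ModelCheckpoint(dirpath="checkpoints/", save_on_train_epoch_end=True)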

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,7 @@
 import os
 import shutil
 import sys
+import time
 from collections import ChainMap, OrderedDict, defaultdict
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
@@ -314,6 +315,9 @@ def on_run_end(self) -> list[_OUT_DICT]:
         if self.verbose and self.trainer.is_global_zero:
             self._print_results(logged_outputs, self._stage.value)

+        now = time.monotonic()
+        self.trainer._last_val_time = now
+
         return logged_outputs

     def teardown(self) -> None:

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 9 additions & 2 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import time
 from dataclasses import dataclass
 from typing import Any, Optional, Union

@@ -283,7 +284,13 @@ def setup_data(self) -> None:
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch

-        if isinstance(trainer.val_check_interval, int):
+        # If time-based validation is enabled, disable batch-based scheduling here.
+        # Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere.
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            trainer.val_check_batch = None
+            trainer._train_start_time = time.monotonic()
+            trainer._last_val_time = trainer._train_start_time
+        elif isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
@@ -299,7 +306,7 @@ def setup_data(self) -> None:
             else:
                 raise MisconfigurationException(
                     "When using an IterableDataset for `train_dataloader`,"
-                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                    " `Trainer(val_check_interval)` must be time based, `1.0` or an int. An int k specifies"
                     " checking validation every k training batches."
                 )
         else:
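
As the relaxed error message suggests, an iterable-style training dataloader can now be paired with a time-based interval instead of a batch count. An illustrative combination, not taken from the diff:

    from lightning.pytorch import Trainer

    # streaming/iterable data: validate on wall-clock time rather than a fixed batch count
    trainer = Trainer(max_steps=100_000, val_check_interval={"hours": 2})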

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 8 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import contextlib
 import math
+import time
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Any, Optional, Union
@@ -534,11 +535,18 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
             # and when the loop allows to stop (min_epochs/steps met)
             return True

+        interval = self.trainer._val_check_time_interval
+        if interval is not None:
+            now = time.monotonic()
+            # if time’s up → tell Trainer to validate
+            return now - self.trainer._last_val_time >= interval
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
+            # if we got here, we’re in batch-based mode, so this can’t be None
+            assert self.trainer.val_check_batch is not None
             # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches
             # else condition it based on the batch_idx of the current epoch
             current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
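
Reduced to a self-contained sketch, the pattern these two loop changes implement (the trigger in `_should_check_val_fx`, the reset in `on_run_end`) looks roughly like this; the class and names below are illustrative, not Trainer internals:

    import time

    class WallClockGate:
        """Fire once a wall-clock interval has elapsed, then reset."""

        def __init__(self, interval_seconds: float) -> None:
            self.interval = interval_seconds
            self.last = time.monotonic()  # monotonic clock: unaffected by system time changes

        def should_fire(self) -> bool:
            return time.monotonic() - self.last >= self.interval

        def reset(self) -> None:
            self.last = time.monotonic()

    gate = WallClockGate(15 * 60)  # 15 minutes
    # training loop: after each batch, check the gate; after validation finishes, reset it
    if gate.should_fire():
        # ... run the validation loop ...
        gate.reset()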

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 4 additions & 3 deletions
@@ -14,6 +14,7 @@
 import os
 from collections.abc import Iterable
 from dataclasses import dataclass, field
+from datetime import timedelta
 from typing import Any, Optional, Union

 import torch.multiprocessing as mp
@@ -50,7 +51,7 @@ def __init__(self, trainer: "pl.Trainer"):

     def on_trainer_init(
         self,
-        val_check_interval: Optional[Union[int, float]],
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
         reload_dataloaders_every_n_epochs: int,
         check_val_every_n_epoch: Optional[int],
     ) -> None:
@@ -63,8 +64,8 @@ def on_trainer_init(

         if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
             raise MisconfigurationException(
-                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
-                f" found {val_check_interval!r}."
+                "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
+                "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
             )

         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
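
A configuration that this check now allows, since only a float is rejected when `check_val_every_n_epoch=None` (illustrative):

    from lightning.pytorch import Trainer

    # iteration-based training: epochs are irrelevant, so validate on elapsed wall-clock time
    trainer = Trainer(check_val_every_n_epoch=None, val_check_interval="00:01:00:00")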

src/lightning/pytorch/trainer/setup.py

Lines changed: 73 additions & 2 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Houses the methods used to set up the Trainer."""

+from datetime import timedelta
 from typing import Optional, Union

 import lightning.pytorch as pl
@@ -40,7 +41,7 @@ def _init_debugging_flags(
     limit_predict_batches: Optional[Union[int, float]],
     fast_dev_run: Union[int, bool],
     overfit_batches: Union[int, float],
-    val_check_interval: Optional[Union[int, float]],
+    val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
     num_sanity_val_steps: int,
 ) -> None:
     # init debugging flags
@@ -69,6 +70,7 @@
         trainer.num_sanity_val_steps = 0
         trainer.fit_loop.max_epochs = 1
         trainer.val_check_interval = 1.0
+        trainer._val_check_time_interval = None  # time not applicable in fast_dev_run
         trainer.check_val_every_n_epoch = 1
         trainer.loggers = [DummyLogger()] if trainer.loggers else []
         rank_zero_info(
@@ -82,7 +84,14 @@
         trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
         trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
         trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
-        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        # Support time-based validation intervals:
+        # If `val_check_interval` is str/dict/timedelta, parse and store seconds on the trainer
+        # for the loops to consume.
+        trainer._val_check_time_interval = None  # default
+        if isinstance(val_check_interval, (str, dict, timedelta)):
+            trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval)
+        else:
+            trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

     if overfit_batches_enabled:
         trainer.limit_train_batches = overfit_batches
@@ -187,3 +196,65 @@ def _log_device_info(trainer: "pl.Trainer") -> None:

     if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
         rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
+
+
+def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float:
+    """Convert a time interval into seconds.
+
+    This helper parses different representations of a time interval and
+    normalizes them into a float number of seconds.
+
+    Supported input formats:
+        * `timedelta`: The total seconds are returned directly.
+        * `dict`: A dictionary of keyword arguments accepted by
+          `datetime.timedelta`, e.g. `{"days": 1, "hours": 2}`.
+        * `str`: A string in the format `"DD:HH:MM:SS"`, where each
+          component must be an integer.
+
+    Args:
+        value (Union[str, timedelta, dict]): The time interval to parse.
+
+    Returns:
+        float: The duration represented by `value` in seconds.
+
+    Raises:
+        MisconfigurationException: If the input type is unsupported, the
+            string format is invalid, or any string component is not an integer.
+
+    Examples:
+        >>> _parse_time_interval_seconds("01:02:03:04")
+        93784.0

+        >>> _parse_time_interval_seconds({"hours": 2, "minutes": 30})
+        9000.0
+
+        >>> from datetime import timedelta
+        >>> _parse_time_interval_seconds(timedelta(days=1, seconds=30))
+        86430.0
+
+    """
+    if isinstance(value, timedelta):
+        return value.total_seconds()
+    if isinstance(value, dict):
+        td = timedelta(**value)
+        return td.total_seconds()
+    if isinstance(value, str):
+        parts = value.split(":")
+        if len(parts) != 4:
+            raise MisconfigurationException(
+                f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
+            )
+        d, h, m, s = parts
+        try:
+            days = int(d)
+            hours = int(h)
+            minutes = int(m)
+            seconds = int(s)
+        except ValueError:
+            raise MisconfigurationException(
+                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
+            )
+        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        return td.total_seconds()
+    # Should not happen given the caller guards
+    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")

src/lightning/pytorch/trainer/trainer.py

Lines changed: 13 additions & 4 deletions
@@ -109,7 +109,7 @@ def __init__(
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
         overfit_batches: Union[int, float] = 0.0,
-        val_check_interval: Optional[Union[int, float]] = None,
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None,
         check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
         log_every_n_steps: Optional[int] = None,
@@ -203,12 +203,21 @@
                 after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
                 batches. An ``int`` value can only be higher than the number of training batches when
                 ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-                across epochs or during iteration-based training.
+                across epochs or during iteration-based training. Additionally, accepts a time-based duration
+                as a string "DD:HH:MM:SS", a :class:`datetime.timedelta`, or a dict of kwargs to
+                :class:`datetime.timedelta`. When time-based, validation triggers once the elapsed wall-clock time
+                since the last validation exceeds the interval; the check occurs after the current batch
+                completes, the validation loop runs, and the timer is reset.
                 Default: ``1.0``.

             check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
                 validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-                to be an integer value.
+                to be an integer value. When used together with a time-based ``val_check_interval`` and
+                ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
+                before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch)
+                and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch.
+                For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without
+                additional alignment.
                 Default: ``1``.

             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
@@ -489,7 +498,7 @@ def __init__(
         self._logger_connector.on_trainer_init(logger, log_every_n_steps)

         # init debugging flags
-        self.val_check_batch: Union[int, float]
+        self.val_check_batch: Optional[Union[int, float]] = None
         self.val_check_interval: Union[int, float]
         self.num_sanity_val_steps: Union[int, float]
         self.limit_train_batches: Union[int, float]
