Merged
Commits
26 commits
98f608f
Add wall-clock val_check_interval with epoch alignment and timer reset
Sohaib-Ahmed21 Aug 14, 2025
e9b66cb
Adjust checkpointing frequency when time-based validation is active
Sohaib-Ahmed21 Aug 14, 2025
f29ee07
Add tests for time based validation
Sohaib-Ahmed21 Aug 14, 2025
655bb77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2025
bbeb0de
Address pre-commit hook failures and add comment to explain checkpoin…
Sohaib-Ahmed21 Aug 14, 2025
3dc4283
Merge branch 'feature/13324_validation-interval' of https://github.co…
Sohaib-Ahmed21 Aug 14, 2025
228f52d
Merge branch 'master' into feature/13324_validation-interval
SkafteNicki Aug 15, 2025
74368ae
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Aug 20, 2025
44e88a4
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Aug 26, 2025
52695b0
Fix mypy check failure and val_check_interval initialization in case …
Sohaib-Ahmed21 Aug 26, 2025
4bcd5e8
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Aug 27, 2025
a610ad2
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Aug 29, 2025
64eb3c7
Update docs for time based validation through val_check_interval
Sohaib-Ahmed21 Aug 30, 2025
d1411ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2025
727ce57
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Aug 30, 2025
74176b5
Solve mypy and docs import failures
Sohaib-Ahmed21 Sep 1, 2025
591adfc
Merge branch 'feature/13324_validation-interval' of https://github.co…
Sohaib-Ahmed21 Sep 1, 2025
06f0c61
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Sep 1, 2025
b67302e
Merge branch 'master' into feature/13324_validation-interval
SkafteNicki Sep 3, 2025
4693672
changelog
SkafteNicki Sep 3, 2025
1b09699
doctest
Borda Sep 3, 2025
69baf1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2025
d4c83e1
Merge branch 'master' into feature/13324_validation-interval
Borda Sep 3, 2025
dbefc8e
Merge branch 'master' into feature/13324_validation-interval
Sohaib-Ahmed21 Sep 4, 2025
eaa8a23
Parametrize val_check_interval test to include different types
Sohaib-Ahmed21 Sep 4, 2025
01c4fef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
13 changes: 12 additions & 1 deletion docs/source-pytorch/advanced/speed.rst
@@ -297,7 +297,8 @@ Validation Within Training Epoch

For large datasets, it's often desirable to check validation multiple times within a training epoch.
Pass in a float to check that often within one training epoch. Pass in an int ``K`` to check every ``K`` training batches.
Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. Alternatively, pass a string ("DD:HH:MM:SS"),
a dict of ``datetime.timedelta`` kwargs, or a ``datetime.timedelta`` to check validation after a given amount of wall-clock time.

.. testcode::

@@ -310,6 +311,16 @@ Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
trainer = Trainer(val_check_interval=100)

# check validation every 15 minutes of wall-clock time
trainer = Trainer(val_check_interval="00:00:15:00")

# alternatively, pass a dict of timedelta kwargs
trainer = Trainer(val_check_interval={"minutes": 1})

# or use a timedelta object directly
from datetime import timedelta
trainer = Trainer(val_check_interval=timedelta(hours=1))

Learn more in our :ref:`trainer_flags` guide.


29 changes: 28 additions & 1 deletion docs/source-pytorch/common/trainer.rst
@@ -989,11 +989,23 @@ val_check_interval
:muted:

How often within one training epoch to check the validation set.
Can specify as float or int.
Can specify as float, int, or a time-based duration.

- pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
- pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or during iteration-based training.
- pass a ``string`` duration in the format "DD:HH:MM:SS", a ``datetime.timedelta`` object, or a ``dictionary`` of keyword arguments that can be passed
to ``datetime.timedelta`` for time-based validation. When using a time-based duration, validation will trigger once the elapsed wall-clock time
since the last validation exceeds the interval. The validation check occurs after the current batch completes, the validation loop runs, and
the timer resets.

**Time-based validation behavior with check_val_every_n_epoch:** When a time-based ``val_check_interval`` is used together with
``check_val_every_n_epoch > 1``, validation is aligned to epoch multiples:

- If the time-based interval elapses **before** the next multiple-N epoch, validation runs at the start of that epoch (after the first batch),
and the timer resets.
- If the interval elapses **during** a multiple-N epoch, validation runs after the current batch.
- When ``check_val_every_n_epoch`` is ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment (a combined sketch follows the examples below).

.. testcode::

@@ -1011,10 +1023,25 @@ Can specify as float or int.
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)

# check validation every 15 minutes of wall-clock time using a string-based approach
trainer = Trainer(val_check_interval="00:00:15:00")

# check validation every 15 minutes of wall-clock time using a dictionary-based approach
trainer = Trainer(val_check_interval={"minutes": 15})

# check validation every 1 hour of wall-clock time using a dictionary-based approach
trainer = Trainer(val_check_interval={"hours": 1})

# check validation every 1 hour of wall-clock time using a datetime.timedelta object
from datetime import timedelta
trainer = Trainer(val_check_interval=timedelta(hours=1))
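
A minimal combined sketch (assuming the alignment behavior described above; not an exhaustive configuration):

.. testcode::

# validate at most every 30 minutes of wall-clock time, aligned to every 2nd epoch
trainer = Trainer(val_check_interval={"minutes": 30}, check_val_every_n_epoch=2)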



.. code-block:: python

# Here is the computation to estimate the total number of batches seen within an epoch.
# This logic applies when `val_check_interval` is specified as an integer or a float.

# Find the total number of train batches
total_train_batches = total_train_samples // (train_batch_size * world_size)
6 changes: 6 additions & 0 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -137,6 +137,8 @@ class ModelCheckpoint(Checkpoint):
If ``True``, checkpoints are saved at the end of every training epoch.
If ``False``, checkpoints are saved at the end of validation.
If ``None`` (default), checkpointing behavior is determined based on training configuration.
If ``val_check_interval`` is a ``str``, ``dict``, or ``timedelta`` (time-based), checkpointing is performed after
validation.
If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
every training epoch. If there are no validation batches of data, checkpointing will occur at the
end of the training epoch. If there is a non-default number of validation runs per training epoch
@@ -517,6 +519,10 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
if self._save_on_train_epoch_end is not None:
return self._save_on_train_epoch_end

# time-based validation: always defer saving to validation end
if getattr(trainer, "_val_check_time_interval", None) is not None:
return False

# if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
# so let's not enforce saving at every training epoch end
if trainer.check_val_every_n_epoch != 1:
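
Illustrative usage sketch (an assumption for illustration, not part of this diff): with a time-based
``val_check_interval``, a default-configured ``ModelCheckpoint`` defers saving to the end of each validation run.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# save_on_train_epoch_end is left at its default (None), so the time-based interval defers saving to validation end
checkpoint_cb = ModelCheckpoint()
trainer = Trainer(val_check_interval="00:00:30:00", callbacks=[checkpoint_cb])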
4 changes: 4 additions & 0 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,6 +15,7 @@
import os
import shutil
import sys
import time
from collections import ChainMap, OrderedDict, defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
@@ -314,6 +315,9 @@ def on_run_end(self) -> list[_OUT_DICT]:
if self.verbose and self.trainer.is_global_zero:
self._print_results(logged_outputs, self._stage.value)

now = time.monotonic()
self.trainer._last_val_time = now

return logged_outputs

def teardown(self) -> None:
11 changes: 9 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional, Union

@@ -283,7 +284,13 @@ def setup_data(self) -> None:
# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = trainer.current_epoch

if isinstance(trainer.val_check_interval, int):
# If time-based validation is enabled, disable batch-based scheduling here.
# Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere.
if getattr(trainer, "_val_check_time_interval", None) is not None:
trainer.val_check_batch = None
trainer._train_start_time = time.monotonic()
trainer._last_val_time = trainer._train_start_time
elif isinstance(trainer.val_check_interval, int):
trainer.val_check_batch = trainer.val_check_interval
if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
raise ValueError(
@@ -299,7 +306,7 @@ def setup_data(self) -> None:
else:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
" `Trainer(val_check_interval)` must be time based, `1.0` or an int. An int k specifies"
" checking validation every k training batches."
)
else:
8 changes: 8 additions & 0 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import math
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Optional, Union
@@ -534,11 +535,18 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
# and when the loop allows to stop (min_epochs/steps met)
return True

interval = self.trainer._val_check_time_interval
if interval is not None:
now = time.monotonic()
# if the wall-clock interval has elapsed, tell the Trainer to validate
return now - self.trainer._last_val_time >= interval
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
# if we got here, we're in batch-based mode, so this can't be None
assert self.trainer.val_check_batch is not None
# if `check_val_every_n_epoch is `None`, run a validation loop every n training batches
# else condition it based on the batch_idx of the current epoch
current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
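
A standalone sketch of the wall-clock gating pattern used above (illustrative names, not Trainer attributes):

import time

class WallClockValGate:
    """Signal validation once `interval_seconds` of wall-clock time elapse since the last reset."""

    def __init__(self, interval_seconds: float) -> None:
        self.interval_seconds = interval_seconds
        self.last_val = time.monotonic()

    def should_validate(self) -> bool:
        # mirrors the `now - trainer._last_val_time >= interval` check in `_should_check_val_fx`
        return time.monotonic() - self.last_val >= self.interval_seconds

    def reset(self) -> None:
        # mirrors the timer reset the evaluation loop performs in `on_run_end`
        self.last_val = time.monotonic()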
7 changes: 4 additions & 3 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -14,6 +14,7 @@
import os
from collections.abc import Iterable
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Optional, Union

import torch.multiprocessing as mp
@@ -50,7 +51,7 @@ def __init__(self, trainer: "pl.Trainer"):

def on_trainer_init(
self,
val_check_interval: Optional[Union[int, float]],
val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
reload_dataloaders_every_n_epochs: int,
check_val_every_n_epoch: Optional[int],
) -> None:
@@ -63,8 +64,8 @@ def on_trainer_init(

if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
raise MisconfigurationException(
"`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
f" found {val_check_interval!r}."
"`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
"datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
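
Illustrative configuration sketch (not a test from this PR) of the rule enforced above:

from lightning.pytorch import Trainer

# a time-based (or integer) interval is accepted with check_val_every_n_epoch=None ...
trainer = Trainer(val_check_interval="00:01:00:00", check_val_every_n_epoch=None)
# ... whereas a fractional interval would raise MisconfigurationException:
# Trainer(val_check_interval=0.5, check_val_every_n_epoch=None)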
41 changes: 39 additions & 2 deletions src/lightning/pytorch/trainer/setup.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Houses the methods used to set up the Trainer."""

from datetime import timedelta
from typing import Optional, Union

import lightning.pytorch as pl
@@ -40,7 +41,7 @@ def _init_debugging_flags(
limit_predict_batches: Optional[Union[int, float]],
fast_dev_run: Union[int, bool],
overfit_batches: Union[int, float],
val_check_interval: Optional[Union[int, float]],
val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
num_sanity_val_steps: int,
) -> None:
# init debugging flags
@@ -69,6 +70,7 @@
trainer.num_sanity_val_steps = 0
trainer.fit_loop.max_epochs = 1
trainer.val_check_interval = 1.0
trainer._val_check_time_interval = None # time not applicable in fast_dev_run
trainer.check_val_every_n_epoch = 1
trainer.loggers = [DummyLogger()] if trainer.loggers else []
rank_zero_info(
@@ -82,7 +84,14 @@
trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
# Support time-based validation intervals:
# If `val_check_interval` is str/dict/timedelta, parse and store seconds on the trainer
# for the loops to consume.
trainer._val_check_time_interval = None # default
if isinstance(val_check_interval, (str, dict, timedelta)):
trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval)
else:
trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

if overfit_batches_enabled:
trainer.limit_train_batches = overfit_batches
@@ -187,3 +196,31 @@ def _log_device_info(trainer: "pl.Trainer") -> None:

if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")


def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float:
if isinstance(value, timedelta):
return value.total_seconds()
if isinstance(value, dict):
td = timedelta(**value)
return td.total_seconds()
if isinstance(value, str):
parts = value.split(":")
if len(parts) != 4:
raise MisconfigurationException(
f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
)
d, h, m, s = parts
try:
days = int(d)
hours = int(h)
minutes = int(m)
seconds = int(s)
except ValueError:
raise MisconfigurationException(
f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
)
td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
return td.total_seconds()
# Should not happen given the caller guards
raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")
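
Illustrative equivalences for the helper above (a sketch, not part of the PR's tests):

assert _parse_time_interval_seconds("00:01:30:00") == 5400.0
assert _parse_time_interval_seconds({"minutes": 90}) == 5400.0
assert _parse_time_interval_seconds(timedelta(hours=1, minutes=30)) == 5400.0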
17 changes: 13 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -109,7 +109,7 @@ def __init__(
limit_test_batches: Optional[Union[int, float]] = None,
limit_predict_batches: Optional[Union[int, float]] = None,
overfit_batches: Union[int, float] = 0.0,
val_check_interval: Optional[Union[int, float]] = None,
val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None,
check_val_every_n_epoch: Optional[int] = 1,
num_sanity_val_steps: Optional[int] = None,
log_every_n_steps: Optional[int] = None,
@@ -203,12 +203,21 @@
after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
batches. An ``int`` value can only be higher than the number of training batches when
``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
across epochs or during iteration-based training.
across epochs or during iteration-based training. Additionally, accepts a time-based duration
as a string "DD:HH:MM:SS", a :class:`datetime.timedelta`, or a dict of kwargs to
:class:`datetime.timedelta`. When time-based, validation triggers once the elapsed wall-clock time
since the last validation exceeds the interval; the check occurs after the current batch
completes, the validation loop runs, and the timer is reset.
Default: ``1.0``.

check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
validation will be done solely based on the number of training batches, requiring ``val_check_interval``
to be an integer value.
to be an integer value. When used together with a time-based ``val_check_interval`` and
``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch)
and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch.
For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without
additional alignment.
Default: ``1``.

num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
@@ -489,7 +498,7 @@ def __init__(
self._logger_connector.on_trainer_init(logger, log_every_n_steps)

# init debugging flags
self.val_check_batch: Union[int, float]
self.val_check_batch: Optional[Union[int, float]] = None
self.val_check_interval: Union[int, float]
self.num_sanity_val_steps: Union[int, float]
self.limit_train_batches: Union[int, float]