diff --git a/docs/source-pytorch/advanced/speed.rst b/docs/source-pytorch/advanced/speed.rst
index 53f2938ab099e..93fe9bfcb458d 100644
--- a/docs/source-pytorch/advanced/speed.rst
+++ b/docs/source-pytorch/advanced/speed.rst
@@ -297,7 +297,8 @@ Validation Within Training Epoch
 
 For large datasets, it's often desirable to check validation multiple times within a training epoch.
 Pass in a float to check that often within one training epoch. Pass in an int ``K`` to check every ``K`` training batch.
-Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
+Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. Alternatively, pass a string ("DD:HH:MM:SS"),
+a dict of ``datetime.timedelta`` kwargs, or a ``datetime.timedelta`` to check validation after a given amount of wall-clock time.
 
 .. testcode::
 
@@ -310,6 +311,16 @@ Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
     # check every 100 train batches (ie: for IterableDatasets or fixed frequency)
     trainer = Trainer(val_check_interval=100)
 
+    # check validation every 15 minutes of wall-clock time
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # alternatively, pass a dict of timedelta kwargs
+    trainer = Trainer(val_check_interval={"minutes": 1})
+
+    # or use a timedelta object directly
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
 Learn more in our :ref:`trainer_flags` guide.
 
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index bb63b64854f7c..a3bdb6bb7b2de 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -991,11 +991,23 @@ val_check_interval
     :muted:
 
 How often within one training epoch to check the validation set.
-Can specify as float or int.
+Can specify as float, int, or a time-based duration.
 
 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
 - pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number
   of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
   across epochs or iteration-based training.
+- pass a ``string`` duration in the format "DD:HH:MM:SS", a ``datetime.timedelta`` object, or a ``dictionary`` of keyword arguments that can be passed
+  to ``datetime.timedelta`` for time-based validation. When using a time-based duration, validation will trigger once the elapsed wall-clock time
+  since the last validation exceeds the interval. The validation check occurs after the current batch completes, the validation loop runs, and
+  the timer resets.
+
+**Time-based validation behavior with check_val_every_n_epoch:** When used together with ``val_check_interval`` (time-based) and
+``check_val_every_n_epoch > 1``, validation is aligned to epoch multiples:
+
+- If the time-based interval elapses **before** the next multiple-N epoch, validation runs at the start of that epoch (after the first batch),
+  and the timer resets.
+- If the interval elapses **during** a multiple-N epoch, validation runs after the current batch.
+- For cases where ``check_val_every_n_epoch=None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment.
 
 .. testcode::
 
@@ -1013,10 +1025,25 @@ Can specify as float or int.
     # (ie: production cases with streaming data)
     trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)
 
+    # check validation every 15 minutes of wall-clock time using a string-based approach
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # check validation every 15 minutes of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"minutes": 15})
+
+    # check validation every 1 hour of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"hours": 1})
+
+    # check validation every 1 hour of wall-clock time using a datetime.timedelta object
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
+
 
 .. code-block:: python
 
     # Here is the computation to estimate the total number of batches seen within an epoch.
+    # This logic applies when `val_check_interval` is specified as an integer or a float.
 
     # Find the total number of train batches
     total_train_batches = total_train_samples // (train_batch_size * world_size)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b3c22a1b1f11f..2f60d0ac96c1d 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
 
+- Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 452e8bdecbba3..dfc0cebb8d07d 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -137,6 +137,8 @@ class ModelCheckpoint(Checkpoint):
             If ``True``, checkpoints are saved at the end of every training epoch.
             If ``False``, checkpoints are saved at the end of validation.
             If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or ``timedelta`` (time-based), checkpointing is performed after
+            validation.
             If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of every training epoch.
             If there are no validation batches of data, checkpointing will occur at the end of the training epoch.
             If there is a non-default number of validation runs per training epoch
@@ -517,6 +519,10 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end
 
+        # time-based validation: always defer saving to validation end
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            return False
+
         # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
         # so let's not enforce saving at every training epoch end
         if trainer.check_val_every_n_epoch != 1:
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index b1e9edfaf7220..6036e57cf59ae 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -15,6 +15,7 @@
 import os
 import shutil
 import sys
+import time
 from collections import ChainMap, OrderedDict, defaultdict
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
@@ -314,6 +315,9 @@ def on_run_end(self) -> list[_OUT_DICT]:
         if self.verbose and self.trainer.is_global_zero:
             self._print_results(logged_outputs, self._stage.value)
 
+        now = time.monotonic()
+        self.trainer._last_val_time = now
+
         return logged_outputs
 
     def teardown(self) -> None:
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index f25c33359a78a..8bb123939dc20 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import time
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
@@ -283,7 +284,13 @@ def setup_data(self) -> None:
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch
 
-        if isinstance(trainer.val_check_interval, int):
+        # If time-based validation is enabled, disable batch-based scheduling here.
+        # Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere.
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            trainer.val_check_batch = None
+            trainer._train_start_time = time.monotonic()
+            trainer._last_val_time = trainer._train_start_time
+        elif isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
@@ -299,7 +306,7 @@ def setup_data(self) -> None:
                 else:
                     raise MisconfigurationException(
                         "When using an IterableDataset for `train_dataloader`,"
-                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " `Trainer(val_check_interval)` must be time-based, `1.0`, or an int. An int k specifies"
                         " checking validation every k training batches."
                     )
             else:
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index c0a57ae12c4d1..3d01780b705fe 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import contextlib
 import math
+import time
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
@@ -534,11 +535,18 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool:
             # and when the loop allows to stop (min_epochs/steps met)
             return True
 
+        interval = self.trainer._val_check_time_interval
+        if interval is not None:
+            now = time.monotonic()
+            # if the time interval has elapsed, tell the Trainer to validate
+            return now - self.trainer._last_val_time >= interval
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
         is_val_check_batch = is_last_batch
         if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
             is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0
         elif self.trainer.val_check_batch != float("inf"):
+            # if we got here, we are in batch-based mode, so `val_check_batch` cannot be None
+            assert self.trainer.val_check_batch is not None
             # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches
             # else condition it based on the batch_idx of the current epoch
             current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 841d78b457d48..240dae6296c1f 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -14,6 +14,7 @@
 import os
 from collections.abc import Iterable
 from dataclasses import dataclass, field
+from datetime import timedelta
 from typing import Any, Optional, Union
 
 import torch.multiprocessing as mp
@@ -50,7 +51,7 @@ def __init__(self, trainer: "pl.Trainer"):
 
     def on_trainer_init(
         self,
-        val_check_interval: Optional[Union[int, float]],
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
        reload_dataloaders_every_n_epochs: int,
         check_val_every_n_epoch: Optional[int],
     ) -> None:
@@ -63,8 +64,8 @@ def on_trainer_init(
 
         if check_val_every_n_epoch is None and isinstance(val_check_interval, float):
             raise MisconfigurationException(
-                "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`,"
-                f" found {val_check_interval!r}."
+                "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
+                "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
             )
 
         self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py
index 00b546b252ac8..73591d30417b8 100644
--- a/src/lightning/pytorch/trainer/setup.py
+++ b/src/lightning/pytorch/trainer/setup.py
@@ -13,6 +13,7 @@
 # limitations under the License.
"""Houses the methods used to set up the Trainer.""" +from datetime import timedelta from typing import Optional, Union import lightning.pytorch as pl @@ -40,7 +41,7 @@ def _init_debugging_flags( limit_predict_batches: Optional[Union[int, float]], fast_dev_run: Union[int, bool], overfit_batches: Union[int, float], - val_check_interval: Optional[Union[int, float]], + val_check_interval: Optional[Union[int, float, str, timedelta, dict]], num_sanity_val_steps: int, ) -> None: # init debugging flags @@ -69,6 +70,7 @@ def _init_debugging_flags( trainer.num_sanity_val_steps = 0 trainer.fit_loop.max_epochs = 1 trainer.val_check_interval = 1.0 + trainer._val_check_time_interval = None # time not applicable in fast_dev_run trainer.check_val_every_n_epoch = 1 trainer.loggers = [DummyLogger()] if trainer.loggers else [] rank_zero_info( @@ -82,7 +84,14 @@ def _init_debugging_flags( trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps - trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") + # Support time-based validation intervals: + # If `val_check_interval` is str/dict/timedelta, parse and store seconds on the trainer + # for the loops to consume. + trainer._val_check_time_interval = None # default + if isinstance(val_check_interval, (str, dict, timedelta)): + trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval) + else: + trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") if overfit_batches_enabled: trainer.limit_train_batches = overfit_batches @@ -187,3 +196,65 @@ def _log_device_info(trainer: "pl.Trainer") -> None: if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator): rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.") + + +def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: + """Convert a time interval into seconds. + + This helper parses different representations of a time interval and + normalizes them into a float number of seconds. + + Supported input formats: + * `timedelta`: The total seconds are returned directly. + * `dict`: A dictionary of keyword arguments accepted by + `datetime.timedelta`, e.g. `{"days": 1, "hours": 2}`. + * `str`: A string in the format `"DD:HH:MM:SS"`, where each + component must be an integer. + + Args: + value (Union[str, timedelta, dict]): The time interval to parse. + + Returns: + float: The duration represented by `value` in seconds. + + Raises: + MisconfigurationException: If the input type is unsupported, the + string format is invalid, or any string component is not an integer. + + Examples: + >>> _parse_time_interval_seconds("01:02:03:04") + 93784.0 + + >>> _parse_time_interval_seconds({"hours": 2, "minutes": 30}) + 9000.0 + + >>> from datetime import timedelta + >>> _parse_time_interval_seconds(timedelta(days=1, seconds=30)) + 86430.0 + + """ + if isinstance(value, timedelta): + return value.total_seconds() + if isinstance(value, dict): + td = timedelta(**value) + return td.total_seconds() + if isinstance(value, str): + parts = value.split(":") + if len(parts) != 4: + raise MisconfigurationException( + f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'." 
+            )
+        d, h, m, s = parts
+        try:
+            days = int(d)
+            hours = int(h)
+            minutes = int(m)
+            seconds = int(s)
+        except ValueError:
+            raise MisconfigurationException(
+                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
+            )
+        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        return td.total_seconds()
+    # Should not happen given the caller guards
+    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index ddad7e9245cd5..5768c507e2e3f 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -109,7 +109,7 @@ def __init__(
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
         overfit_batches: Union[int, float] = 0.0,
-        val_check_interval: Optional[Union[int, float]] = None,
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None,
         check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
         log_every_n_steps: Optional[int] = None,
@@ -203,12 +203,21 @@ def __init__(
                 after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
                 batches. An ``int`` value can only be higher than the number of training batches when
                 ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
-                across epochs or during iteration-based training.
+                across epochs or during iteration-based training. Additionally, accepts a time-based duration
+                as a string "DD:HH:MM:SS", a :class:`datetime.timedelta`, or a dict of kwargs to
+                :class:`datetime.timedelta`. When time-based, validation triggers once the elapsed wall-clock time
+                since the last validation exceeds the interval; the check occurs after the current batch
+                completes, the validation loop runs, and the timer is reset.
                 Default: ``1.0``.
 
             check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
                 validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-                to be an integer value.
+                to be an integer value. When used together with a time-based ``val_check_interval`` and
+                ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
+                before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch)
+                and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch.
+                For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without
+                additional alignment.
                 Default: ``1``.
 
             num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
@@ -489,7 +498,7 @@ def __init__(
         self._logger_connector.on_trainer_init(logger, log_every_n_steps)
 
         # init debugging flags
-        self.val_check_batch: Union[int, float]
+        self.val_check_batch: Optional[Union[int, float]] = None
         self.val_check_interval: Union[int, float]
         self.num_sanity_val_steps: Union[int, float]
         self.limit_train_batches: Union[int, float]
diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
index b6cc446cb0840..53c95e40d2c20 100644
--- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
+++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+import time
+from datetime import timedelta
+from unittest.mock import patch
 
 import pytest
 from torch.utils.data import DataLoader
@@ -127,9 +131,118 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
     """Test that an exception is raised when `val_check_interval` is set to float with `check_val_every_n_epoch=None`"""
     with pytest.raises(
-        MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`"
+        MisconfigurationException,
+        match=re.escape(
+            "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
+            "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
+        ),
     ):
         Trainer(
             val_check_interval=0.5,
             check_val_every_n_epoch=None,
         )
+
+
+@pytest.mark.parametrize(
+    "interval",
+    [
+        "00:00:00:02",
+        {"seconds": 2},
+        timedelta(seconds=2),
+    ],
+)
+def test_time_based_val_check_interval(tmp_path, interval):
+    call_count = {"count": 0}
+
+    def fake_time():
+        result = call_count["count"]
+        call_count["count"] += 2
+        return result
+
+    with patch("time.monotonic", side_effect=fake_time):
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            logger=False,
+            enable_checkpointing=False,
+            max_epochs=1,
+            max_steps=5,  # 5 steps: simulate 10s total wall-clock time
+            limit_val_batches=1,
+            val_check_interval=interval,  # every 2s
+        )
+        model = BoringModel()
+        trainer.fit(model)
+
+    # Assert 5 validations happened
+    val_runs = trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed
+    # With the mocked clock advancing 2s per check and a 2s interval, every batch triggers validation
+    assert val_runs == 5, f"Expected 5 validations, got {val_runs}"
+
+
+@pytest.mark.parametrize(
+    ("check_val_every_n_epoch", "val_check_interval", "epoch_duration", "expected_val_batches", "description"),
+    [
+        (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"),
+        (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"),
+        (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer shorter than epoch"),
+        (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer much shorter than epoch"),
+        (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer longer than epoch"),
+    ],
+)
+def test_time_and_epoch_gated_val_check(
+    tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description
+):
+    call_count = {"count": 0}
+
+    # Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch)
+    def fake_time():
+        result = call_count["count"]
+        call_count["count"] += 1
+        return result
+
+    # Custom model to record when validation happens (on what epoch)
+    class TestModel(BoringModel):
+        val_batches = []
+        val_epoch_calls = 0
+
+        def on_train_batch_end(self, *args, **kwargs):
+            if (
+                isinstance(self.trainer.check_val_every_n_epoch, int)
+                and self.trainer.check_val_every_n_epoch > 1
+                and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0
+            ):
+                # consume one tick of the mocked clock on epochs where validation is gated off
+                time.monotonic()
+
+        def on_train_epoch_end(self, *args, **kwargs):
+            self.val_batches.append(self.trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed)
+
+        def on_validation_epoch_start(self) -> None:
+            self.val_epoch_calls += 1
+
+    max_epochs = 5
+    max_steps = max_epochs * epoch_duration
+    limit_train_batches = epoch_duration
+
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "logger": False,
+        "enable_checkpointing": False,
+        "max_epochs": max_epochs,
+        "max_steps": max_steps,
+        "limit_val_batches": 1,
+        "limit_train_batches": limit_train_batches,
+        "val_check_interval": val_check_interval,
+        "check_val_every_n_epoch": check_val_every_n_epoch,
+    }
+
+    with patch("time.monotonic", side_effect=fake_time):
+        model = TestModel()
+        trainer = Trainer(**trainer_kwargs)
+        trainer.fit(model)
+
+    # Check at which epochs validation happened
+    assert model.val_batches == expected_val_batches, (
+        f"\nFAILED: {description}"
+        f"\nExpected validation at batches: {expected_val_batches},"
+        f"\nGot: {model.val_batches, model.val_epoch_calls}\n"
+    )
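Reviewer note: the sketch below is not part of the patch. It is a minimal standalone illustration, under stated assumptions, of the behavior the diff implements: parse the interval into seconds, compare `time.monotonic()` deltas against it after each batch, and reset the timer after validating. The names `parse_interval_seconds` and `run_training` are hypothetical and only mirror `_parse_time_interval_seconds` and the check added to `_should_check_val_fx` above.

```python
import time
from datetime import timedelta
from typing import Union


def parse_interval_seconds(value: Union[str, dict, timedelta]) -> float:
    # Same three accepted forms as the PR: a timedelta, a dict of timedelta kwargs,
    # or a "DD:HH:MM:SS" string with integer components.
    if isinstance(value, timedelta):
        return value.total_seconds()
    if isinstance(value, dict):
        return timedelta(**value).total_seconds()
    days, hours, minutes, seconds = (int(part) for part in value.split(":"))
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds).total_seconds()


def run_training(num_batches: int, val_check_interval: Union[str, dict, timedelta]) -> None:
    # Validate once the elapsed wall-clock time since the last validation exceeds the
    # interval; the check happens after the current batch finishes, then the timer resets.
    interval = parse_interval_seconds(val_check_interval)
    last_val_time = time.monotonic()
    for batch_idx in range(num_batches):
        ...  # train on one batch
        if time.monotonic() - last_val_time >= interval:
            ...  # run the validation loop
            last_val_time = time.monotonic()


run_training(num_batches=100, val_check_interval={"minutes": 15})
```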