From 98f608feabb5d421bd775828eae6387832c9eab1 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 14 Aug 2025 17:15:44 +0500 Subject: [PATCH 01/14] Add wall-clock val_check_interval with epoch alignment and timer reset --- .../pytorch/loops/evaluation_loop.py | 4 ++ src/lightning/pytorch/loops/fit_loop.py | 11 ++++- .../pytorch/loops/training_epoch_loop.py | 8 ++++ .../trainer/connectors/data_connector.py | 7 ++-- src/lightning/pytorch/trainer/setup.py | 42 ++++++++++++++++++- src/lightning/pytorch/trainer/trainer.py | 14 +++++-- 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index b1e9edfaf7220..e04fdcc00c581 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -19,6 +19,7 @@ from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, Optional, Union +import time from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -313,6 +314,9 @@ def on_run_end(self) -> list[_OUT_DICT]: if self.verbose and self.trainer.is_global_zero: self._print_results(logged_outputs, self._stage.value) + + now = time.monotonic() + self.trainer._last_val_time = now return logged_outputs diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 31d6724a043a3..ca6d1a5f54b7b 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -14,6 +14,7 @@ import logging from dataclasses import dataclass from typing import Any, Optional, Union +import time import torch from typing_extensions import override @@ -283,7 +284,13 @@ def setup_data(self) -> None: # store epoch of dataloader reset for reload_dataloaders_every_n_epochs self._last_train_dl_reload_epoch = trainer.current_epoch - if isinstance(trainer.val_check_interval, int): + # If time-based validation is enabled, disable batch-based scheduling here. + # Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere. + if getattr(trainer, "_val_check_time_interval", None) is not None: + trainer.val_check_batch = None + trainer._train_start_time = time.monotonic() + trainer._last_val_time = trainer._train_start_time + elif isinstance(trainer.val_check_interval, int): trainer.val_check_batch = trainer.val_check_interval if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None: raise ValueError( @@ -299,7 +306,7 @@ def setup_data(self) -> None: else: raise MisconfigurationException( "When using an IterableDataset for `train_dataloader`," - " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies" + " `Trainer(val_check_interval)` must be time based, `1.0` or an int. An int k specifies" " checking validation every k training batches." 
) else: diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index c0a57ae12c4d1..eba200e1b965a 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -16,6 +16,7 @@ from collections import OrderedDict from dataclasses import dataclass from typing import Any, Optional, Union +import time import torch from typing_extensions import override @@ -534,6 +535,13 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool: # and when the loop allows to stop (min_epochs/steps met) return True + interval = self.trainer._val_check_time_interval + if interval is not None: + now = time.monotonic() + if now - self.trainer._last_val_time >= interval: + # time’s up → tell Trainer to validate + return True + return False # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = is_last_batch if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 841d78b457d48..cf8995d3df6e2 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -15,6 +15,7 @@ from collections.abc import Iterable from dataclasses import dataclass, field from typing import Any, Optional, Union +from datetime import timedelta import torch.multiprocessing as mp from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler @@ -50,7 +51,7 @@ def __init__(self, trainer: "pl.Trainer"): def on_trainer_init( self, - val_check_interval: Optional[Union[int, float]], + val_check_interval: Optional[Union[int, float, str, timedelta, dict]], reload_dataloaders_every_n_epochs: int, check_val_every_n_epoch: Optional[int], ) -> None: @@ -63,8 +64,8 @@ def on_trainer_init( if check_val_every_n_epoch is None and isinstance(val_check_interval, float): raise MisconfigurationException( - "`val_check_interval` should be an integer when `check_val_every_n_epoch=None`," - f" found {val_check_interval!r}." + "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', " + "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`." 
) self.trainer.check_val_every_n_epoch = check_val_every_n_epoch diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 00b546b252ac8..df0d2037716cd 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -14,6 +14,7 @@ """Houses the methods used to set up the Trainer.""" from typing import Optional, Union +from datetime import timedelta import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -40,7 +41,7 @@ def _init_debugging_flags( limit_predict_batches: Optional[Union[int, float]], fast_dev_run: Union[int, bool], overfit_batches: Union[int, float], - val_check_interval: Optional[Union[int, float]], + val_check_interval: Optional[Union[int, float, str, timedelta, dict]], num_sanity_val_steps: int, ) -> None: # init debugging flags @@ -69,6 +70,7 @@ def _init_debugging_flags( trainer.num_sanity_val_steps = 0 trainer.fit_loop.max_epochs = 1 trainer.val_check_interval = 1.0 + trainer._val_check_time_interval = None # time not applicable in fast_dev_run trainer.check_val_every_n_epoch = 1 trainer.loggers = [DummyLogger()] if trainer.loggers else [] rank_zero_info( @@ -82,7 +84,16 @@ def _init_debugging_flags( trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps - trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") + # Support time-based validation intervals: + # If `val_check_interval` is str/dict/timedelta, parse and store seconds on the trainer + # for the loops to consume. + trainer._val_check_time_interval = None # default + if isinstance(val_check_interval, (str, dict, timedelta)): + trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval) + # Keep the numeric scheduler neutral; loops should check the time-based attribute. + trainer.val_check_interval = 1.0 + else: + trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") if overfit_batches_enabled: trainer.limit_train_batches = overfit_batches @@ -187,3 +198,30 @@ def _log_device_info(trainer: "pl.Trainer") -> None: if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator): rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.") + +def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: + if isinstance(value, timedelta): + return value.total_seconds() + if isinstance(value, dict): + td = timedelta(**value) + return td.total_seconds() + if isinstance(value, str): + parts = value.split(":") + if len(parts) != 4: + raise MisconfigurationException( + f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'." + ) + d, h, m, s = parts + try: + days = int(d) + hours = int(h) + minutes = int(m) + seconds = int(s) + except ValueError: + raise MisconfigurationException( + f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'." 
+ ) + td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) + return td.total_seconds() + # Should not happen given the caller guards + raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}") \ No newline at end of file diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 732521b1b0ce7..e6b2e64759ce0 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -109,7 +109,7 @@ def __init__( limit_test_batches: Optional[Union[int, float]] = None, limit_predict_batches: Optional[Union[int, float]] = None, overfit_batches: Union[int, float] = 0.0, - val_check_interval: Optional[Union[int, float]] = None, + val_check_interval: Optional[Union[int, float, str, timedelta, dict[str,int]]] = None, check_val_every_n_epoch: Optional[int] = 1, num_sanity_val_steps: Optional[int] = None, log_every_n_steps: Optional[int] = None, @@ -203,12 +203,20 @@ def __init__( after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches - across epochs or during iteration-based training. + across epochs or during iteration-based training. Additionally, accepts a time-based duration + as a string "DD:HH:MM:SS", a :class:`datetime.timedelta`, or a dict of kwargs to + :class:`datetime.timedelta`. When time-based, validation triggers once the elapsed wall-clock time + since the last validation exceeds the interval; the check occurs after the current batch + completes, the validation loop runs, and the timer is reset. Default: ``1.0``. check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``, validation will be done solely based on the number of training batches, requiring ``val_check_interval`` - to be an integer value. + to be an integer value. When used together with a time-based ``val_check_interval`` and + ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses + before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) + and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. + For ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment. Default: ``1``. num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. 
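A minimal standalone sketch of the wall-clock trigger this patch wires into `_should_check_val_fx` and `EvaluationLoop.on_run_end` (plain Python, not importing Lightning; the `val_check_time_interval` / `last_val_time` names mirror the trainer attributes added above):

    import time

    val_check_time_interval = 2.0        # seconds, e.g. parsed from "00:00:00:02"
    last_val_time = time.monotonic()     # set when training starts

    for batch_idx in range(5):
        time.sleep(0.5)                  # stand-in for one training batch
        now = time.monotonic()
        if now - last_val_time >= val_check_time_interval:
            print(f"validate after batch {batch_idx}")
            last_val_time = now          # reset mirrors EvaluationLoop.on_run_end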
From e9b66cb4aeef758e63c4fc328ba989c276acd375 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 14 Aug 2025 17:21:30 +0500 Subject: [PATCH 02/14] Adjust checkpointing frequency when time-based validation is active --- src/lightning/pytorch/callbacks/model_checkpoint.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 68fed2ff82d31..20894d4946154 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -467,6 +467,10 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: if self._save_on_train_epoch_end is not None: return self._save_on_train_epoch_end + # time-based validation: always defer saving to validation end + if getattr(trainer, "_val_check_time_interval", None) is not None: + return False + # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded # so let's not enforce saving at every training epoch end if trainer.check_val_every_n_epoch != 1: From f29ee077536c78cb95750ac2a9db6d5693da1538 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 14 Aug 2025 17:22:14 +0500 Subject: [PATCH 03/14] Add tests for time based validation --- .../trainer/flags/test_val_check_interval.py | 97 ++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index b6cc446cb0840..e571651663606 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -14,6 +14,9 @@ import logging import pytest +import time +import re +from unittest.mock import patch from torch.utils.data import DataLoader from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset @@ -127,9 +130,101 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch(): """Test that an exception is raised when `val_check_interval` is set to float with `check_val_every_n_epoch=None`""" with pytest.raises( - MisconfigurationException, match="`val_check_interval` should be an integer when `check_val_every_n_epoch=None`" + MisconfigurationException, + match=re.escape( + "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', " + "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`." 
+ ) ): Trainer( val_check_interval=0.5, check_val_every_n_epoch=None, ) + +def test_time_based_val_check_interval(tmp_path): + call_count = {"count": 0} + def fake_time(): + result = call_count["count"] + call_count["count"] += 2 + return result + + with patch("time.monotonic", side_effect=fake_time): + trainer = Trainer( + default_root_dir=tmp_path, + logger=False, + enable_checkpointing=False, + max_epochs=1, + max_steps=5, # 5 steps: simulate 10s total wall-clock time + limit_val_batches=1, + val_check_interval="00:00:00:02", # every 2s + ) + model = BoringModel() + trainer.fit(model) + + # Assert 5 validations happened + val_runs = trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed + # The number of validation runs should be equal to the number of times we called fake_time + assert val_runs == 5, f"Expected 5 validations, got {val_runs}" + + +@pytest.mark.parametrize( + "check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description", + [ + (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"), + (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"), + (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"), + (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"), + (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"), + ] +) +def test_time_and_epoch_gated_val_check(tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description): + call_count = {"count": 0} + # Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch) + def fake_time(): + result = call_count["count"] + call_count["count"] += 1 + return result + + # Custom model to record when validation happens (on what epoch) + class TestModel(BoringModel): + val_batches = [] + val_epoch_calls = 0 + + def on_train_batch_end(self, *args, **kwargs): + if isinstance(self.trainer.check_val_every_n_epoch,int) and self.trainer.check_val_every_n_epoch > 1 and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + time.monotonic() + + def on_train_epoch_end(self, *args, **kwargs): + print(trainer.fit_loop.epoch_loop.val_loop.batch_progress.current.completed) + self.val_batches.append(trainer.fit_loop.epoch_loop.val_loop.batch_progress.total.completed) + + def on_validation_epoch_start(self) -> None: + self.val_epoch_calls += 1 + + max_epochs = 5 + max_steps = max_epochs * epoch_duration + limit_train_batches = epoch_duration + + trainer_kwargs = dict( + default_root_dir=tmp_path, + logger=False, + enable_checkpointing=False, + max_epochs=max_epochs, + max_steps=max_steps, + limit_val_batches=1, + limit_train_batches=limit_train_batches, + val_check_interval=val_check_interval, + check_val_every_n_epoch=check_val_every_n_epoch + ) + + with patch("time.monotonic", side_effect=fake_time): + model = TestModel() + trainer = Trainer(**trainer_kwargs) + trainer.fit(model) + + # Validate which epochs validation happened + assert model.val_batches == expected_val_batches, ( + f"\nFAILED: {description}" + f"\nExpected validation at batches: {expected_val_batches}," + f"\nGot: {model.val_batches, model.val_epoch_calls}\n" + ) \ No newline at end of file From 655bb77d8c4b0b34268c9da977b387c3a2775fb6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:41:03 +0000 Subject: [PATCH 04/14] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/callbacks/model_checkpoint.py | 2 +- .../pytorch/loops/evaluation_loop.py | 4 +- src/lightning/pytorch/loops/fit_loop.py | 4 +- .../pytorch/loops/training_epoch_loop.py | 2 +- .../trainer/connectors/data_connector.py | 2 +- src/lightning/pytorch/trainer/setup.py | 53 ++++++++++--------- src/lightning/pytorch/trainer/trainer.py | 4 +- .../trainer/flags/test_val_check_interval.py | 53 +++++++++++-------- 8 files changed, 67 insertions(+), 57 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 20894d4946154..5f6e0ccb2a703 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -470,7 +470,7 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: # time-based validation: always defer saving to validation end if getattr(trainer, "_val_check_time_interval", None) is not None: return False - + # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded # so let's not enforce saving at every training epoch end if trainer.check_val_every_n_epoch != 1: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index e04fdcc00c581..6036e57cf59ae 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,11 +15,11 @@ import os import shutil import sys +import time from collections import ChainMap, OrderedDict, defaultdict from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, Optional, Union -import time from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -314,7 +314,7 @@ def on_run_end(self) -> list[_OUT_DICT]: if self.verbose and self.trainer.is_global_zero: self._print_results(logged_outputs, self._stage.value) - + now = time.monotonic() self.trainer._last_val_time = now diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index ca6d1a5f54b7b..9f2490d1e86fb 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import time from dataclasses import dataclass from typing import Any, Optional, Union -import time import torch from typing_extensions import override @@ -289,7 +289,7 @@ def setup_data(self) -> None: if getattr(trainer, "_val_check_time_interval", None) is not None: trainer.val_check_batch = None trainer._train_start_time = time.monotonic() - trainer._last_val_time = trainer._train_start_time + trainer._last_val_time = trainer._train_start_time elif isinstance(trainer.val_check_interval, int): trainer.val_check_batch = trainer.val_check_interval if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None: diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index eba200e1b965a..a2eee9d3fa75b 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -13,10 +13,10 @@ # limitations under the License. 
import contextlib import math +import time from collections import OrderedDict from dataclasses import dataclass from typing import Any, Optional, Union -import time import torch from typing_extensions import override diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index cf8995d3df6e2..240dae6296c1f 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -14,8 +14,8 @@ import os from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Optional, Union from datetime import timedelta +from typing import Any, Optional, Union import torch.multiprocessing as mp from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index df0d2037716cd..10d94aaf0d5ec 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -13,8 +13,8 @@ # limitations under the License. """Houses the methods used to set up the Trainer.""" -from typing import Optional, Union from datetime import timedelta +from typing import Optional, Union import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -199,29 +199,30 @@ def _log_device_info(trainer: "pl.Trainer") -> None: if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator): rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.") + def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: - if isinstance(value, timedelta): - return value.total_seconds() - if isinstance(value, dict): - td = timedelta(**value) - return td.total_seconds() - if isinstance(value, str): - parts = value.split(":") - if len(parts) != 4: - raise MisconfigurationException( - f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'." - ) - d, h, m, s = parts - try: - days = int(d) - hours = int(h) - minutes = int(m) - seconds = int(s) - except ValueError: - raise MisconfigurationException( - f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'." - ) - td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) - return td.total_seconds() - # Should not happen given the caller guards - raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}") \ No newline at end of file + if isinstance(value, timedelta): + return value.total_seconds() + if isinstance(value, dict): + td = timedelta(**value) + return td.total_seconds() + if isinstance(value, str): + parts = value.split(":") + if len(parts) != 4: + raise MisconfigurationException( + f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'." + ) + d, h, m, s = parts + try: + days = int(d) + hours = int(h) + minutes = int(m) + seconds = int(s) + except ValueError: + raise MisconfigurationException( + f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'." 
+ ) + td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) + return td.total_seconds() + # Should not happen given the caller guards + raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}") diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index e6b2e64759ce0..e6618776e4f85 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -109,7 +109,7 @@ def __init__( limit_test_batches: Optional[Union[int, float]] = None, limit_predict_batches: Optional[Union[int, float]] = None, overfit_batches: Union[int, float] = 0.0, - val_check_interval: Optional[Union[int, float, str, timedelta, dict[str,int]]] = None, + val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None, check_val_every_n_epoch: Optional[int] = 1, num_sanity_val_steps: Optional[int] = None, log_every_n_steps: Optional[int] = None, @@ -212,7 +212,7 @@ def __init__( check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``, validation will be done solely based on the number of training batches, requiring ``val_check_interval`` - to be an integer value. When used together with a time-based ``val_check_interval`` and + to be an integer value. When used together with a time-based ``val_check_interval`` and ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index e571651663606..939b8861b1cf7 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -import pytest -import time import re +import time from unittest.mock import patch + +import pytest from torch.utils.data import DataLoader from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset @@ -132,17 +132,19 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch(): with pytest.raises( MisconfigurationException, match=re.escape( - "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', " - "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`." - ) + "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', " + "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`." 
+ ), ): Trainer( val_check_interval=0.5, check_val_every_n_epoch=None, ) + def test_time_based_val_check_interval(tmp_path): call_count = {"count": 0} + def fake_time(): result = call_count["count"] call_count["count"] += 2 @@ -168,17 +170,20 @@ def fake_time(): @pytest.mark.parametrize( - "check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description", + ("check_val_every_n_epoch", "val_check_interval", "epoch_duration", "expected_val_batches", "description"), [ (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"), (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"), (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"), (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"), (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"), - ] + ], ) -def test_time_and_epoch_gated_val_check(tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description): +def test_time_and_epoch_gated_val_check( + tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description +): call_count = {"count": 0} + # Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch) def fake_time(): result = call_count["count"] @@ -191,7 +196,11 @@ class TestModel(BoringModel): val_epoch_calls = 0 def on_train_batch_end(self, *args, **kwargs): - if isinstance(self.trainer.check_val_every_n_epoch,int) and self.trainer.check_val_every_n_epoch > 1 and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + if ( + isinstance(self.trainer.check_val_every_n_epoch, int) + and self.trainer.check_val_every_n_epoch > 1 + and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0 + ): time.monotonic() def on_train_epoch_end(self, *args, **kwargs): @@ -205,17 +214,17 @@ def on_validation_epoch_start(self) -> None: max_steps = max_epochs * epoch_duration limit_train_batches = epoch_duration - trainer_kwargs = dict( - default_root_dir=tmp_path, - logger=False, - enable_checkpointing=False, - max_epochs=max_epochs, - max_steps=max_steps, - limit_val_batches=1, - limit_train_batches=limit_train_batches, - val_check_interval=val_check_interval, - check_val_every_n_epoch=check_val_every_n_epoch - ) + trainer_kwargs = { + "default_root_dir": tmp_path, + "logger": False, + "enable_checkpointing": False, + "max_epochs": max_epochs, + "max_steps": max_steps, + "limit_val_batches": 1, + "limit_train_batches": limit_train_batches, + "val_check_interval": val_check_interval, + "check_val_every_n_epoch": check_val_every_n_epoch, + } with patch("time.monotonic", side_effect=fake_time): model = TestModel() @@ -227,4 +236,4 @@ def on_validation_epoch_start(self) -> None: f"\nFAILED: {description}" f"\nExpected validation at batches: {expected_val_batches}," f"\nGot: {model.val_batches, model.val_epoch_calls}\n" - ) \ No newline at end of file + ) From bbeb0def49f99ceaf5c33d4c6c68a23acbec0853 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 14 Aug 2025 17:56:50 +0500 Subject: [PATCH 05/14] Address pre-commit hook failures and add comment to explain checkpointing logic with time based validation. 
--- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 ++ src/lightning/pytorch/loops/training_epoch_loop.py | 6 ++---- src/lightning/pytorch/trainer/trainer.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 20894d4946154..dfb6bf0fca6bd 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -137,6 +137,8 @@ class ModelCheckpoint(Checkpoint): If ``True``, checkpoints are saved at the end of every training epoch. If ``False``, checkpoints are saved at the end of validation. If ``None`` (default), checkpointing behavior is determined based on training configuration. + If ``val_check_interval`` is a str, dict, or `timedelta` (time-based), checkpointing is performed after + validation. If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of every training epoch. If there are no validation batches of data, checkpointing will occur at the end of the training epoch. If there is a non-default number of validation runs per training epoch diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index eba200e1b965a..a02e3f62915b3 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -538,10 +538,8 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool: interval = self.trainer._val_check_time_interval if interval is not None: now = time.monotonic() - if now - self.trainer._last_val_time >= interval: - # time’s up → tell Trainer to validate - return True - return False + # if time’s up → tell Trainer to validate + return now - self.trainer._last_val_time >= interval # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = is_last_batch if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index e6b2e64759ce0..c2e9421d7f644 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -216,7 +216,8 @@ def __init__( ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. - For ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment. + For ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional + alignment. Default: ``1``. num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. From 52695b0a223a2d3d6946c27c0efd7be7945ec32f Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Wed, 27 Aug 2025 00:38:47 +0500 Subject: [PATCH 06/14] Fix mypy check failure and val_check_interval initialization in case of time based validation. 
--- src/lightning/pytorch/trainer/setup.py | 2 -- src/lightning/pytorch/trainer/trainer.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 10d94aaf0d5ec..a01c48a40712f 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -90,8 +90,6 @@ def _init_debugging_flags( trainer._val_check_time_interval = None # default if isinstance(val_check_interval, (str, dict, timedelta)): trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval) - # Keep the numeric scheduler neutral; loops should check the time-based attribute. - trainer.val_check_interval = 1.0 else: trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval") diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index d6304e593d59b..80e72168d3ef0 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -498,7 +498,7 @@ def __init__( self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags - self.val_check_batch: Union[int, float] + self.val_check_batch: Optional[Union[int, float]] = None self.val_check_interval: Union[int, float] self.num_sanity_val_steps: Union[int, float] self.limit_train_batches: Union[int, float] From 64eb3c7161e46babe5eeacf62741fb8b911317e1 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sun, 31 Aug 2025 00:29:34 +0500 Subject: [PATCH 07/14] Update docs for time based validation through val_check_interval --- docs/source-pytorch/advanced/speed.rst | 13 ++++++++++- docs/source-pytorch/common/trainer.rst | 28 +++++++++++++++++++++++- src/lightning/pytorch/trainer/trainer.py | 4 ++-- 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/docs/source-pytorch/advanced/speed.rst b/docs/source-pytorch/advanced/speed.rst index 53f2938ab099e..93fe9bfcb458d 100644 --- a/docs/source-pytorch/advanced/speed.rst +++ b/docs/source-pytorch/advanced/speed.rst @@ -297,7 +297,8 @@ Validation Within Training Epoch For large datasets, it's often desirable to check validation multiple times within a training epoch. Pass in a float to check that often within one training epoch. Pass in an int ``K`` to check every ``K`` training batch. -Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. +Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. Alternatively, pass a string ("DD:HH:MM:SS"), +a dict of ``datetime.timedelta`` kwargs, or a ``datetime.timedelta`` to check validation after a given amount of wall-clock time. .. testcode:: @@ -310,6 +311,16 @@ Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. # check every 100 train batches (ie: for IterableDatasets or fixed frequency) trainer = Trainer(val_check_interval=100) + # check validation every 15 minutes of wall-clock time + trainer = Trainer(val_check_interval="00:00:15:00") + + # alternatively, pass a dict of timedelta kwargs + trainer = Trainer(val_check_interval={"minutes": 1}) + + # or use a timedelta object directly + from datetime import timedelta + trainer = Trainer(val_check_interval=timedelta(hours=1)) + Learn more in our :ref:`trainer_flags` guide. 
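As a rough standalone illustration of how the three accepted duration formats normalize to seconds (mirroring the `_parse_time_interval_seconds` helper added in patch 01; `to_seconds` here is only a local sketch, not a Lightning API):

    from datetime import timedelta

    def to_seconds(value):
        if isinstance(value, timedelta):
            return value.total_seconds()
        if isinstance(value, dict):
            return timedelta(**value).total_seconds()
        # "DD:HH:MM:SS" string form
        days, hours, minutes, seconds = (int(part) for part in value.split(":"))
        return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds).total_seconds()

    print(to_seconds("00:00:15:00"))       # 900.0
    print(to_seconds({"minutes": 1}))      # 60.0
    print(to_seconds(timedelta(hours=1)))  # 3600.0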
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 6a8a8135a1843..9436092623daf 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -989,11 +989,23 @@ val_check_interval :muted: How often within one training epoch to check the validation set. -Can specify as float or int. +Can specify as float, int, or a time-based duration. - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch. - pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training. +- pass a ``string`` duration in the format "DD:HH:MM:SS", a ``datetime.timedelta`` object, or a ``dictionary`` of keyword arguments that can be passed + to ``datetime.timedelta`` for time-based validation. When using a time-based duration, validation will trigger once the elapsed wall-clock time + since the last validation exceeds the interval. The validation check occurs after the current batch completes, the validation loop runs, and + the timer resets. + +**Time-based validation behavior with check_val_every_n_epoch:** When used together with ``val_check_interval`` (time-based) and +``check_val_every_n_epoch > 1``, validation is aligned to epoch multiples: + +- If the time-based interval elapses **before** the next multiple-N epoch, validation runs at the start of that epoch (after the first batch), + and the timer resets. +- If the interval elapses **during** a multiple-N epoch, validation runs after the current batch. +- For cases where ``check_val_every_n_epoch=None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment. .. testcode:: @@ -1011,10 +1023,24 @@ Can specify as float or int. # (ie: production cases with streaming data) trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None) + # check validation every 15 minutes of wall-clock time using a string-based approach + trainer = Trainer(val_check_interval="00:00:15:00") + + # check validation every 15 minutes of wall-clock time using a dictionary-based approach + trainer = Trainer(val_check_interval={"minutes": 15}) + + # check validation every 1 hour of wall-clock time using a dictionary-based approach + trainer = Trainer(val_check_interval={"hours": 1}) + + # check validation every 1 hour of wall-clock time using a datetime.timedelta object + trainer = Trainer(val_check_interval=timedelta(hours=1)) + + .. code-block:: python # Here is the computation to estimate the total number of batches seen within an epoch. + # This logic applies when `val_check_interval` is specified as an integer or a float. # Find the total number of train batches total_train_batches = total_train_samples // (train_batch_size * world_size) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 80e72168d3ef0..179431f2c223b 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -216,8 +216,8 @@ def __init__( ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. 
- For ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional - alignment. + For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without + additional alignment. Default: ``1``. num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. From d1411ff4a3f44fa9392e2689fa37da76136dca40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Aug 2025 19:30:12 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 179431f2c223b..131b1470e0082 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -216,7 +216,7 @@ def __init__( ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch) and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch. - For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without + For ``None`` or ``1`` cases, the time-based behavior of ``val_check_interval`` applies without additional alignment. Default: ``1``. From 74176b54c1ee9f23bfa6cbeebc77e632d34a4344 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Mon, 1 Sep 2025 23:10:23 +0500 Subject: [PATCH 09/14] Solve mypy and docs import failures --- docs/source-pytorch/common/trainer.rst | 1 + src/lightning/pytorch/loops/training_epoch_loop.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 9436092623daf..fa095a5188e14 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -1033,6 +1033,7 @@ Can specify as float, int, or a time-based duration. 
trainer = Trainer(val_check_interval={"hours": 1}) # check validation every 1 hour of wall-clock time using a datetime.timedelta object + from datetime import timedelta trainer = Trainer(val_check_interval=timedelta(hours=1)) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index d4c7b8a25144d..3d01780b705fe 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -545,6 +545,8 @@ def _should_check_val_fx(self, data_fetcher: _DataFetcher) -> bool: if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: is_val_check_batch = (self.batch_idx + 1) % self.trainer.limit_train_batches == 0 elif self.trainer.val_check_batch != float("inf"): + # if we got here, we’re in batch-based mode, so this can’t be None + assert self.trainer.val_check_batch is not None # if `check_val_every_n_epoch is `None`, run a validation loop every n training batches # else condition it based on the batch_idx of the current epoch current_iteration = self.total_batch_idx if self.trainer.check_val_every_n_epoch is None else self.batch_idx From 469367280ea03b31b8f67f65d4f7a30742feef81 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 3 Sep 2025 06:52:56 +0200 Subject: [PATCH 10/14] changelog --- src/lightning/pytorch/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b3c22a1b1f11f..2f60d0ac96c1d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -22,6 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146)) +- Added time based validation support though `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071)) + + ### Changed - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) From 1b09699b821d355e5831a831e28a0412d1de4ddc Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 3 Sep 2025 08:08:34 +0200 Subject: [PATCH 11/14] doctest Co-authored-by: Nicki Skafte Detlefsen --- src/lightning/pytorch/trainer/setup.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index a01c48a40712f..9f3a4845980b6 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -199,6 +199,39 @@ def _log_device_info(trainer: "pl.Trainer") -> None: def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: + """Convert a time interval into seconds. + + This helper parses different representations of a time interval and + normalizes them into a float number of seconds. + + Supported input formats: + * `timedelta`: The total seconds are returned directly. + * `dict`: A dictionary of keyword arguments accepted by + `datetime.timedelta`, e.g. `{"days": 1, "hours": 2}`. + * `str`: A string in the format `"DD:HH:MM:SS"`, where each + component must be an integer. + + Args: + value (Union[str, timedelta, dict]): The time interval to parse. 
+ + Returns: + float: The duration represented by `value` in seconds. + + Raises: + MisconfigurationException: If the input type is unsupported, the + string format is invalid, or any string component is not an integer. + + Examples: + >>> _parse_time_interval_seconds("01:02:03:04") + 93784.0 + + >>> _parse_time_interval_seconds({"hours": 2, "minutes": 30}) + 9000.0 + + >>> from datetime import timedelta + >>> _parse_time_interval_seconds(timedelta(days=1, seconds=30)) + 86430.0 + """ if isinstance(value, timedelta): return value.total_seconds() if isinstance(value, dict): From 69baf1b7ebb63ee8844bdc661aa7420af3bf276b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 06:09:51 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/trainer/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 9f3a4845980b6..73591d30417b8 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -231,6 +231,7 @@ def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float: >>> from datetime import timedelta >>> _parse_time_interval_seconds(timedelta(days=1, seconds=30)) 86430.0 + """ if isinstance(value, timedelta): return value.total_seconds() From eaa8a232c9df91509099f68c211cfb094499803f Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 4 Sep 2025 10:49:18 +0500 Subject: [PATCH 13/14] Parametrize val_check_interval test to include different types --- .../trainer/flags/test_val_check_interval.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index 939b8861b1cf7..32d46e51a8b83 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -14,6 +14,7 @@ import logging import re import time +from datetime import timedelta from unittest.mock import patch import pytest @@ -141,8 +142,15 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch(): check_val_every_n_epoch=None, ) - -def test_time_based_val_check_interval(tmp_path): +@pytest.mark.parametrize( + "interval", + [ + "00:00:00:02", + {"seconds": 2}, + timedelta(seconds=2), + ], +) +def test_time_based_val_check_interval(tmp_path, interval): call_count = {"count": 0} def fake_time(): @@ -158,7 +166,7 @@ def fake_time(): max_epochs=1, max_steps=5, # 5 steps: simulate 10s total wall-clock time limit_val_batches=1, - val_check_interval="00:00:00:02", # every 2s + val_check_interval=interval, # every 2s ) model = BoringModel() trainer.fit(model) @@ -174,9 +182,9 @@ def fake_time(): [ (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"), (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"), - (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"), - (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"), - (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"), + (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer shorter than epoch"), + (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer much shorter than epoch"), + 
(2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer longer than epoch"), ], ) def test_time_and_epoch_gated_val_check( From 01c4fef9b5474dbfc0f8d0af57544febd847cb44 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Sep 2025 05:49:46 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/trainer/flags/test_val_check_interval.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index 32d46e51a8b83..53c95e40d2c20 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -142,6 +142,7 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch(): check_val_every_n_epoch=None, ) + @pytest.mark.parametrize( "interval", [