Commit 169e20c

Merge branch 'master' into fsdp-grad-clip-by-norm
2 parents dee2225 + 1fc077b commit 169e20c

File tree

24 files changed: +482, -65 lines

docs/source-pytorch/advanced/speed.rst

Lines changed: 12 additions & 1 deletion
@@ -297,7 +297,8 @@ Validation Within Training Epoch
 
 For large datasets, it's often desirable to check validation multiple times within a training epoch.
 Pass in a float to check that often within one training epoch. Pass in an int ``K`` to check every ``K`` training batch.
-Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
+Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`. Alternatively, pass a string ("DD:HH:MM:SS"),
+a dict of ``datetime.timedelta`` kwargs, or a ``datetime.timedelta`` to check validation after a given amount of wall-clock time.
 
 .. testcode::
 
@@ -310,6 +311,16 @@ Must use an ``int`` if using an :class:`~torch.utils.data.IterableDataset`.
     # check every 100 train batches (ie: for IterableDatasets or fixed frequency)
     trainer = Trainer(val_check_interval=100)
 
+    # check validation every 15 minutes of wall-clock time
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # alternatively, pass a dict of timedelta kwargs
+    trainer = Trainer(val_check_interval={"minutes": 1})
+
+    # or use a timedelta object directly
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
 Learn more in our :ref:`trainer_flags` guide.
 
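The documented "DD:HH:MM:SS" format maps directly onto `datetime.timedelta` fields. As a rough illustration of that mapping (the `parse_duration` helper below is hypothetical and not part of this commit):

    # Hypothetical helper: shows how a "DD:HH:MM:SS" duration string corresponds
    # to a datetime.timedelta. Lightning's own parsing may differ in its details.
    from datetime import timedelta

    def parse_duration(value: str) -> timedelta:
        days, hours, minutes, seconds = (int(part) for part in value.split(":"))
        return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)

    assert parse_duration("00:00:15:00") == timedelta(minutes=15)
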
docs/source-pytorch/common/trainer.rst

Lines changed: 28 additions & 1 deletion
@@ -991,11 +991,23 @@ val_check_interval
     :muted:
 
 How often within one training epoch to check the validation set.
-Can specify as float or int.
+Can specify as float, int, or a time-based duration.
 
 - pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
 - pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
   batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training.
+- pass a ``string`` duration in the format "DD:HH:MM:SS", a ``datetime.timedelta`` object, or a ``dictionary`` of keyword arguments that can be passed
+  to ``datetime.timedelta`` for time-based validation. When using a time-based duration, validation triggers once the elapsed wall-clock time
+  since the last validation exceeds the interval. The validation check occurs after the current batch completes, the validation loop runs, and
+  the timer resets.
+
+**Time-based validation behavior with check_val_every_n_epoch:** When a time-based ``val_check_interval`` is used together with
+``check_val_every_n_epoch > 1``, validation is aligned to epoch multiples:
+
+- If the time-based interval elapses **before** the next multiple-of-N epoch, validation runs at the start of that epoch (after the first batch),
+  and the timer resets.
+- If the interval elapses **during** a multiple-of-N epoch, validation runs after the current batch.
+- If ``check_val_every_n_epoch`` is ``None`` or ``1``, the time-based behavior of ``val_check_interval`` applies without additional alignment.
 
 .. testcode::
 
@@ -1013,10 +1025,25 @@ Can specify as float or int.
     # (ie: production cases with streaming data)
     trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)
 
+    # check validation every 15 minutes of wall-clock time using a string-based approach
+    trainer = Trainer(val_check_interval="00:00:15:00")
+
+    # check validation every 15 minutes of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"minutes": 15})
+
+    # check validation every 1 hour of wall-clock time using a dictionary-based approach
+    trainer = Trainer(val_check_interval={"hours": 1})
+
+    # check validation every 1 hour of wall-clock time using a datetime.timedelta object
+    from datetime import timedelta
+    trainer = Trainer(val_check_interval=timedelta(hours=1))
+
+
 
 .. code-block:: python
 
     # Here is the computation to estimate the total number of batches seen within an epoch.
+    # This logic applies when `val_check_interval` is specified as an integer or a float.
 
     # Find the total number of train batches
     total_train_batches = total_train_samples // (train_batch_size * world_size)
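To make the epoch-alignment rules documented above concrete, here is a small, self-contained sketch of the decision they describe. It is an illustration under stated assumptions, not the code path Lightning uses; the function name and arguments are invented for the example:

    from datetime import timedelta

    def should_validate_now(elapsed_seconds: float, interval: timedelta, aligned_epoch: bool) -> bool:
        """Illustrative decision only.

        elapsed_seconds: wall-clock time since the last validation run.
        aligned_epoch: True when check_val_every_n_epoch is None or 1, or when
        the current epoch is a multiple of N.
        """
        if elapsed_seconds < interval.total_seconds():
            return False  # timer has not expired yet
        return aligned_epoch  # expired: validate once an eligible epoch is reached

    # A 15-minute interval that expired midway through an eligible epoch:
    assert should_validate_now(20 * 60, timedelta(minutes=15), aligned_epoch=True)

When validation does run, the evaluation loop change further below resets the timer by recording `time.monotonic()` at the end of the loop.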

src/lightning/fabric/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -22,14 +22,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- let `_get_default_process_group_backend_for_device` support more hardware platforms (
+  [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
 
 
 ### Fixed
 
 - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
 
 
+- Respect `verbose=False` in `seed_everything` when no seed is provided
+
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/fabric/strategies/ddp.py

Lines changed: 11 additions & 1 deletion
@@ -160,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
         else:
-            torch.distributed.barrier()
+            # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+            try:
+                torch.distributed.barrier()
+            except RuntimeError as e:
+                if "PrivateUse1HooksInterface" in str(e):
+                    # Fallback: Use all_reduce as barrier - all processes must participate
+                    # This achieves the same synchronization effect as barrier()
+                    dummy_tensor = torch.tensor(0.0, device=self.root_device)
+                    torch.distributed.all_reduce(dummy_tensor)
+                else:
+                    raise
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
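For context on the new lookup, a quick check of what it resolves to; the exact contents of `default_device_backend_map` depend on the installed PyTorch build, so the CUDA assertion below is an assumption about a CUDA-enabled build:

    import torch
    from lightning.fabric.utilities.distributed import _get_default_process_group_backend_for_device

    # Typical defaults (build-dependent): {"cpu": "gloo", "cuda": "nccl", ...}
    print(torch.distributed.Backend.default_device_backend_map)

    assert _get_default_process_group_backend_for_device(torch.device("cpu")) == "gloo"
    assert _get_default_process_group_backend_for_device(torch.device("cuda")) == "nccl"  # assumes a CUDA build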

src/lightning/fabric/utilities/seed.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
     env_seed = os.environ.get("PL_GLOBAL_SEED")
     if env_seed is None:
         seed = 0
-        rank_zero_warn(f"No seed found, seed set to {seed}")
+        if verbose:
+            rank_zero_warn(f"No seed found, seed set to {seed}")
     else:
         try:
             seed = int(env_seed)
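A minimal usage sketch of the behavior this change enables, assuming `PL_GLOBAL_SEED` is not set in the environment:

    from lightning.fabric.utilities.seed import seed_everything

    # No seed argument and no PL_GLOBAL_SEED: falls back to seed 0.
    seed_everything()                # warns "No seed found, seed set to 0"
    seed_everything(verbose=False)   # same fallback, but the warning is now suppressed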

src/lightning/pytorch/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
 
 
+- Added time-based validation support through `val_check_interval` ([#21071](https://github.com/Lightning-AI/pytorch-lightning/pull/21071))
+
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
@@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
 
+
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
 ---
 
 ## [2.5.4] - 2025-08-29

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 6 additions & 0 deletions
@@ -137,6 +137,8 @@ class ModelCheckpoint(Checkpoint):
             If ``True``, checkpoints are saved at the end of every training epoch.
             If ``False``, checkpoints are saved at the end of validation.
             If ``None`` (default), checkpointing behavior is determined based on training configuration.
+            If ``val_check_interval`` is a str, dict, or ``timedelta`` (time-based), checkpointing is performed after
+            validation.
             If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of
             every training epoch. If there are no validation batches of data, checkpointing will occur at the
             end of the training epoch. If there is a non-default number of validation runs per training epoch
@@ -517,6 +519,10 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         if self._save_on_train_epoch_end is not None:
             return self._save_on_train_epoch_end
 
+        # time-based validation: always defer saving to validation end
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            return False
+
         # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
         # so let's not enforce saving at every training epoch end
         if trainer.check_val_every_n_epoch != 1:
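A brief usage sketch of the documented interaction; the model, dataloaders, and the `val_loss` metric name are placeholders for the example:

    from datetime import timedelta
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # With a time-based val_check_interval, ModelCheckpoint defers saving to the
    # end of each validation run rather than the end of the training epoch.
    checkpoint = ModelCheckpoint(monitor="val_loss")
    trainer = Trainer(val_check_interval=timedelta(minutes=30), callbacks=[checkpoint])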

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,7 @@
 import os
 import shutil
 import sys
+import time
 from collections import ChainMap, OrderedDict, defaultdict
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
@@ -314,6 +315,9 @@ def on_run_end(self) -> list[_OUT_DICT]:
         if self.verbose and self.trainer.is_global_zero:
             self._print_results(logged_outputs, self._stage.value)
 
+        now = time.monotonic()
+        self.trainer._last_val_time = now
+
         return logged_outputs
 
     def teardown(self) -> None:

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 9 additions & 2 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import time
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
@@ -283,7 +284,13 @@ def setup_data(self) -> None:
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
         self._last_train_dl_reload_epoch = trainer.current_epoch
 
-        if isinstance(trainer.val_check_interval, int):
+        # If time-based validation is enabled, disable batch-based scheduling here.
+        # Use None to clearly signal "no batch-based validation"; wall-time logic will run elsewhere.
+        if getattr(trainer, "_val_check_time_interval", None) is not None:
+            trainer.val_check_batch = None
+            trainer._train_start_time = time.monotonic()
+            trainer._last_val_time = trainer._train_start_time
+        elif isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
@@ -299,7 +306,7 @@ def setup_data(self) -> None:
                 else:
                     raise MisconfigurationException(
                         "When using an IterableDataset for `train_dataloader`,"
-                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " `Trainer(val_check_interval)` must be time-based, `1.0`, or an int. An int k specifies"
                         " checking validation every k training batches."
                     )
             else:
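The hunks above only set up the timer state (`_train_start_time`, `_last_val_time`) and disable batch-based scheduling; the check that actually consumes the interval is not part of this excerpt. A rough sketch of what that wall-clock check could look like, with names and placement assumed rather than taken from the implementation:

    import time

    def _should_check_val_by_time(trainer) -> bool:
        # Assumed attributes: trainer._val_check_time_interval (timedelta or None)
        # and trainer._last_val_time (monotonic timestamp of the last validation).
        interval = getattr(trainer, "_val_check_time_interval", None)
        if interval is None:
            return False  # time-based validation not enabled
        elapsed = time.monotonic() - trainer._last_val_time
        return elapsed >= interval.total_seconds()

    # Sketch of use inside the batch loop: run validation after the current batch
    # whenever the interval has elapsed; evaluation_loop.on_run_end() then resets
    # trainer._last_val_time via time.monotonic().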
