Skip to content

Commit 727ce57

Browse files
Merge branch 'master' into feature/13324_validation-interval
2 parents d1411ff + b1cc925 commit 727ce57

File tree

9 files changed

+571
-19
lines changed

9 files changed

+571
-19
lines changed

.lightning/workflows/fabric.yml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@ trigger:
44
pull_request:
55
branches: ["master"]
66

7-
timeout: "75" # minutes
8-
machine: "L4_X_2"
7+
timeout: "55" # minutes
98
parametrize:
109
matrix: {}
1110
include:
12-
# note that this is setting also all oldest requirements which is linked to Torch == 2.0
11+
# note that this is setting also all oldest requirements which is linked to Torch == 2.1
1312
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
1413
PACKAGE_NAME: "fabric"
15-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
14+
machine: "A100_X_2"
15+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
1616
PACKAGE_NAME: "fabric"
17+
machine: "L4_X_2"
1718
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
1819
# PACKAGE_NAME: "fabric"
19-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
20+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
2021
PACKAGE_NAME: "lightning"
22+
machine: "L4_X_2"
2123
exclude: []
2224

2325
env:
@@ -30,6 +32,7 @@ run: |
3032
python --version
3133
pip --version
3234
pip install -q fire wget packaging
35+
pip list
3336
set -ex
3437
3538
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
4043
echo "Torch URL: ${TORCH_URL}"
4144
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
4245
echo "collecting coverage for: ${COVERAGE_SOURCE}"
46+
TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
4347
4448
if [ "${TORCH_VER}" == "2.1" ]; then
4549
echo "Set oldest versions"
46-
cd requirements/fabric
50+
pip uninstall -y deepspeed
4751
pip install -U "lightning-utilities[cli]"
52+
cd requirements/fabric
4853
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']"
54+
python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
4955
cd ../..
5056
pip install "cython<3.0" wheel # for compatibility
5157
fi
@@ -92,6 +98,7 @@ run: |
9298
export PL_RUN_STANDALONE_TESTS=1
9399
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
94100
bash ./run_standalone_tests.sh "tests_fabric"
101+
export PL_RUN_STANDALONE_TESTS=0
95102
96103
# echo "Reporting coverage" # todo
97104
# python -m coverage report

.lightning/workflows/pytorch.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@ trigger:
44
pull_request:
55
branches: ["master"]
66

7-
timeout: "75" # minutes
8-
machine: "L4_X_2"
7+
timeout: "55" # minutes
98
parametrize:
109
matrix: {}
1110
include:
12-
# note that this is setting also all oldest requirements which is linked to Torch == 2.0
11+
# note that this is setting also all oldest requirements which is linked to Torch == 2.1
1312
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
1413
PACKAGE_NAME: "pytorch"
15-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
14+
machine: "A100_X_2"
15+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
1616
PACKAGE_NAME: "pytorch"
17+
machine: "L4_X_2"
1718
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
1819
# PACKAGE_NAME: "pytorch"
19-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
20+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
2021
PACKAGE_NAME: "lightning"
22+
machine: "L4_X_2"
2123
exclude: []
2224

2325
env:
@@ -30,6 +32,7 @@ run: |
3032
python --version
3133
pip --version
3234
pip install -q fire wget packaging
35+
pip list
3336
set -ex
3437
3538
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
4043
echo "Torch URL: ${TORCH_URL}"
4144
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="pytorch_lightning").get(n, n))')
4245
echo "collecting coverage for: ${COVERAGE_SOURCE}"
46+
TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
4347
4448
if [ "${TORCH_VER}" == "2.1" ]; then
45-
recho "Set oldest versions"
46-
cd requirements/pytorch
49+
echo "Set oldest versions"
50+
pip uninstall -y deepspeed
4751
pip install -U "lightning-utilities[cli]"
52+
cd requirements/pytorch
4853
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']"
54+
python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
4955
cd ../..
5056
pip install "cython<3.0" wheel # for compatibility
5157
fi
@@ -108,6 +114,7 @@ run: |
108114
export PL_RUN_STANDALONE_TESTS=1
109115
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
110116
bash ./run_standalone_tests.sh "tests_pytorch"
117+
export PL_RUN_STANDALONE_TESTS=0
111118
112119
echo "Testing: PyTorch standalone tasks"
113120
cd tests_pytorch/

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
ARG UBUNTU_VERSION=22.04
16-
ARG CUDA_VERSION=11.7.1
16+
ARG CUDA_VERSION=12.1.1
1717

1818

1919
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

requirements/pytorch/test.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in
1919

2020
tensorboard >=2.11, <2.21.0 # for `TensorBoardLogger`
2121

22-
--find-links https://download.pytorch.org/whl/torch-tensorrt
2322
torch-tensorrt; platform_system == "Linux" and python_version >= "3.12"

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2828

2929
### Fixed
3030

31-
-
31+
- Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
32+
3233

3334

3435
---

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ def __init__(
262262
self.best_model_path = ""
263263
self.last_model_path = ""
264264
self._last_checkpoint_saved = ""
265+
# When using step/time-based checkpointing with a validation-only monitored metric,
266+
# defer the save until validation has produced the metric
267+
self._defer_save_until_validation: bool = False
265268

266269
self.kth_value: Tensor
267270
self.dirpath: Optional[_PATH]
@@ -308,14 +311,17 @@ def on_train_batch_end(
308311
batch_idx: int,
309312
) -> None:
310313
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
311-
if self._should_skip_saving_checkpoint(trainer):
312-
return
314+
# Do not return early here because we may need to set deferral flags even
315+
# if a save already happened at this global step. We'll enforce the skip
316+
# just before actually saving below.
317+
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
313318
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
314319

315320
train_time_interval = self._train_time_interval
316321
skip_time = True
317322
now = time.monotonic()
318-
if train_time_interval:
323+
# Important: allow zero timedelta as a valid interval
324+
if train_time_interval is not None:
319325
prev_time_check = self._last_time_checked
320326
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
321327
# in case we have time differences across ranks
@@ -328,6 +334,42 @@ def on_train_batch_end(
328334
self._last_time_checked = now
329335

330336
monitor_candidates = self._monitor_candidates(trainer)
337+
# If monitoring a metric that is not yet available (e.g., validation-only),
338+
# defer saving until validation end so the metric is present.
339+
if self.monitor is not None and self.monitor not in monitor_candidates:
340+
# Defer both top-k and last to avoid blocking with `_last_global_step_saved`
341+
self._defer_save_until_validation = True
342+
return
343+
344+
# Even if the monitored key exists, it could be stale from a previous validation.
345+
# If validation is scheduled to run right after this batch (e.g., last batch of epoch)
346+
# and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
347+
if (
348+
self.monitor is not None
349+
and not self._should_save_on_train_epoch_end(trainer)
350+
and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
351+
):
352+
# Only defer if a validation loop is expected to run after this batch.
353+
will_run_val = False
354+
if getattr(trainer, "enable_validation", False):
355+
num_val_batches = (
356+
sum(trainer.num_val_batches)
357+
if isinstance(trainer.num_val_batches, list)
358+
else trainer.num_val_batches
359+
)
360+
if num_val_batches and num_val_batches > 0:
361+
cve = trainer.check_val_every_n_epoch
362+
if cve is None or ((trainer.current_epoch + 1) % cve == 0):
363+
will_run_val = True
364+
365+
if will_run_val:
366+
self._defer_save_until_validation = True
367+
return
368+
369+
# Only proceed to save if not skipping due to trainer/callback state
370+
if skip_due_to_state:
371+
return
372+
331373
self._save_topk_checkpoint(trainer, monitor_candidates)
332374
self._save_last_checkpoint(trainer, monitor_candidates)
333375

@@ -345,6 +387,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
345387
"""Save a checkpoint at the end of the validation stage."""
346388
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
347389
monitor_candidates = self._monitor_candidates(trainer)
390+
# If a step/time-triggered save was deferred due to a missing monitored metric,
391+
# perform the save now that validation metrics are available.
392+
if self._defer_save_until_validation:
393+
self._save_topk_checkpoint(trainer, monitor_candidates)
394+
self._save_last_checkpoint(trainer, monitor_candidates)
395+
self._defer_save_until_validation = False
396+
return
397+
348398
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
349399
self._save_topk_checkpoint(trainer, monitor_candidates)
350400
self._save_last_checkpoint(trainer, monitor_candidates)

0 commit comments

Comments
 (0)