Skip to content

Commit 52b3215

Browse files
authored
Merge branch 'master' into weights-only-compatibility
2 parents 9aadbef + b1cc925 commit 52b3215

File tree

8 files changed

+571
-18
lines changed

8 files changed

+571
-18
lines changed

.lightning/workflows/fabric.yml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@ trigger:
44
pull_request:
55
branches: ["master"]
66

7-
timeout: "75" # minutes
8-
machine: "L4_X_2"
7+
timeout: "55" # minutes
98
parametrize:
109
matrix: {}
1110
include:
12-
# note that this is setting also all oldest requirements which is linked to Torch == 2.0
11+
# note that this is setting also all oldest requirements which is linked to Torch == 2.1
1312
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
1413
PACKAGE_NAME: "fabric"
15-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
14+
machine: "A100_X_2"
15+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
1616
PACKAGE_NAME: "fabric"
17+
machine: "L4_X_2"
1718
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
1819
# PACKAGE_NAME: "fabric"
19-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
20+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
2021
PACKAGE_NAME: "lightning"
22+
machine: "L4_X_2"
2123
exclude: []
2224

2325
env:
@@ -30,6 +32,7 @@ run: |
3032
python --version
3133
pip --version
3234
pip install -q fire wget packaging
35+
pip list
3336
set -ex
3437
3538
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
4043
echo "Torch URL: ${TORCH_URL}"
4144
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
4245
echo "collecting coverage for: ${COVERAGE_SOURCE}"
46+
TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
4347
4448
if [ "${TORCH_VER}" == "2.1" ]; then
4549
echo "Set oldest versions"
46-
cd requirements/fabric
50+
pip uninstall -y deepspeed
4751
pip install -U "lightning-utilities[cli]"
52+
cd requirements/fabric
4853
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']"
54+
python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
4955
cd ../..
5056
pip install "cython<3.0" wheel # for compatibility
5157
fi
@@ -92,6 +98,7 @@ run: |
9298
export PL_RUN_STANDALONE_TESTS=1
9399
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
94100
bash ./run_standalone_tests.sh "tests_fabric"
101+
export PL_RUN_STANDALONE_TESTS=0
95102
96103
# echo "Reporting coverage" # todo
97104
# python -m coverage report

.lightning/workflows/pytorch.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@ trigger:
44
pull_request:
55
branches: ["master"]
66

7-
timeout: "75" # minutes
8-
machine: "L4_X_2"
7+
timeout: "55" # minutes
98
parametrize:
109
matrix: {}
1110
include:
12-
# note that this is setting also all oldest requirements which is linked to Torch == 2.0
11+
# note that this is setting also all oldest requirements which is linked to Torch == 2.1
1312
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
1413
PACKAGE_NAME: "pytorch"
15-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
14+
machine: "A100_X_2"
15+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
1616
PACKAGE_NAME: "pytorch"
17+
machine: "L4_X_2"
1718
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
1819
# PACKAGE_NAME: "pytorch"
19-
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
20+
- image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
2021
PACKAGE_NAME: "lightning"
22+
machine: "L4_X_2"
2123
exclude: []
2224

2325
env:
@@ -30,6 +32,7 @@ run: |
3032
python --version
3133
pip --version
3234
pip install -q fire wget packaging
35+
pip list
3336
set -ex
3437
3538
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
4043
echo "Torch URL: ${TORCH_URL}"
4144
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="pytorch_lightning").get(n, n))')
4245
echo "collecting coverage for: ${COVERAGE_SOURCE}"
46+
TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
4347
4448
if [ "${TORCH_VER}" == "2.1" ]; then
45-
recho "Set oldest versions"
46-
cd requirements/pytorch
49+
echo "Set oldest versions"
50+
pip uninstall -y deepspeed
4751
pip install -U "lightning-utilities[cli]"
52+
cd requirements/pytorch
4853
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']"
54+
python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
4955
cd ../..
5056
pip install "cython<3.0" wheel # for compatibility
5157
fi
@@ -108,6 +114,7 @@ run: |
108114
export PL_RUN_STANDALONE_TESTS=1
109115
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
110116
bash ./run_standalone_tests.sh "tests_pytorch"
117+
export PL_RUN_STANDALONE_TESTS=0
111118
112119
echo "Testing: PyTorch standalone tasks"
113120
cd tests_pytorch/

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
ARG UBUNTU_VERSION=22.04
16-
ARG CUDA_VERSION=11.7.1
16+
ARG CUDA_VERSION=12.1.1
1717

1818

1919
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3131

3232
### Fixed
3333

34-
-
34+
- Fixed callbacks by deferring step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
35+
3536

3637

3738
---

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,9 @@ def __init__(
260260
self.best_model_path = ""
261261
self.last_model_path = ""
262262
self._last_checkpoint_saved = ""
263+
# When using step/time-based checkpointing with a validation-only monitored metric,
264+
# defer the save until validation has produced the metric
265+
self._defer_save_until_validation: bool = False
263266

264267
self.kth_value: Tensor
265268
self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
306309
batch_idx: int,
307310
) -> None:
308311
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
309-
if self._should_skip_saving_checkpoint(trainer):
310-
return
312+
# Do not return early here because we may need to set deferral flags even
313+
# if a save already happened at this global step. We'll enforce the skip
314+
# just before actually saving below.
315+
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
311316
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
312317

313318
train_time_interval = self._train_time_interval
314319
skip_time = True
315320
now = time.monotonic()
316-
if train_time_interval:
321+
# Important: allow zero timedelta as a valid interval
322+
if train_time_interval is not None:
317323
prev_time_check = self._last_time_checked
318324
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
319325
# in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
326332
self._last_time_checked = now
327333

328334
monitor_candidates = self._monitor_candidates(trainer)
335+
# If monitoring a metric that is not yet available (e.g., validation-only),
336+
# defer saving until validation end so the metric is present.
337+
if self.monitor is not None and self.monitor not in monitor_candidates:
338+
# Defer both top-k and last to avoid blocking with `_last_global_step_saved`
339+
self._defer_save_until_validation = True
340+
return
341+
342+
# Even if the monitored key exists, it could be stale from a previous validation.
343+
# If validation is scheduled to run right after this batch (e.g., last batch of epoch)
344+
# and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
345+
if (
346+
self.monitor is not None
347+
and not self._should_save_on_train_epoch_end(trainer)
348+
and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
349+
):
350+
# Only defer if a validation loop is expected to run after this batch.
351+
will_run_val = False
352+
if getattr(trainer, "enable_validation", False):
353+
num_val_batches = (
354+
sum(trainer.num_val_batches)
355+
if isinstance(trainer.num_val_batches, list)
356+
else trainer.num_val_batches
357+
)
358+
if num_val_batches and num_val_batches > 0:
359+
cve = trainer.check_val_every_n_epoch
360+
if cve is None or ((trainer.current_epoch + 1) % cve == 0):
361+
will_run_val = True
362+
363+
if will_run_val:
364+
self._defer_save_until_validation = True
365+
return
366+
367+
# Only proceed to save if not skipping due to trainer/callback state
368+
if skip_due_to_state:
369+
return
370+
329371
self._save_topk_checkpoint(trainer, monitor_candidates)
330372
self._save_last_checkpoint(trainer, monitor_candidates)
331373

@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
343385
"""Save a checkpoint at the end of the validation stage."""
344386
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
345387
monitor_candidates = self._monitor_candidates(trainer)
388+
# If a step/time-triggered save was deferred due to a missing monitored metric,
389+
# perform the save now that validation metrics are available.
390+
if self._defer_save_until_validation:
391+
self._save_topk_checkpoint(trainer, monitor_candidates)
392+
self._save_last_checkpoint(trainer, monitor_candidates)
393+
self._defer_save_until_validation = False
394+
return
395+
346396
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
347397
self._save_topk_checkpoint(trainer, monitor_candidates)
348398
self._save_last_checkpoint(trainer, monitor_candidates)

0 commit comments

Comments
 (0)