Conversation

littlebullGit (Contributor) commented Aug 21, 2025

Defer step/time-triggered ModelCheckpoint saves until validation metrics are available

Fixes #20919

Root cause

  • With `every_n_train_steps` (or `train_time_interval`), checkpoints could save at train-batch end before validation ran. The monitored validation metric was missing or stale at that point, so `best_model_score` could be incorrect.
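
To make the failure mode concrete, here is a minimal repro sketch. It assumes `BoringModel` from `lightning.pytorch.demos.boring_classes`; the `ValLossModel` subclass, batch counts, and interval are illustrative, not taken from the PR:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class ValLossModel(BoringModel):
    # "val_loss" exists only after a validation loop has run, which is exactly
    # the window in which a step-interval save used to fire prematurely.
    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("val_loss", loss)
        return loss


ckpt = ModelCheckpoint(monitor="val_loss", every_n_train_steps=4, save_top_k=1)
trainer = Trainer(
    max_epochs=2,
    limit_train_batches=8,  # the step trigger fires mid-epoch, before validation
    limit_val_batches=2,
    callbacks=[ckpt],
    logger=False,
    enable_progress_bar=False,
)
trainer.fit(ValLossModel())
```

Before the fix, the step-interval saves in the first epoch could record a missing or stale "val_loss"; after it, those saves wait for validation to finish.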

Fix

  • In [src/lightning/pytorch/callbacks/model_checkpoint.py]:
    • [ModelCheckpoint.on_train_batch_end]:
      • Defer saves when the monitored key is missing from [trainer.callback_metrics].
      • If at the last train batch and not saving at train-epoch-end, defer only when validation will run next:
        • `trainer.enable_validation` is True
        • `trainer.num_val_batches` > 0
        • the `trainer.check_val_every_n_epoch` schedule matches the upcoming epoch
    • [ModelCheckpoint.on_validation_end]:
      • Perform the deferred save so fresh validation metrics are used.
    • Allow a zero `timedelta` for `train_time_interval`, and broadcast the time-trigger decision across ranks via `trainer.strategy.broadcast`.
    • No deferral when monitoring a train metric or when validation won’t run (a condensed sketch of this decision follows the list).
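
To make the decision concrete, here is a condensed sketch of the check described above, written as a hypothetical free function. It is not the actual implementation; the trainer attribute names are the ones this list uses, and `num_val_batches` is assumed to be a per-dataloader list as in current Lightning:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


def should_defer_save(cb: ModelCheckpoint, trainer: Trainer) -> bool:
    # Nothing to defer when there is no monitor, or when the monitored key is
    # already present (e.g. a train metric logged every step).
    if cb.monitor is None or cb.monitor in trainer.callback_metrics:
        return False
    # Defer only when validation will actually run next and can refresh the
    # metric; otherwise deferring would silently skip the save.
    check_every = trainer.check_val_every_n_epoch or 1
    return (
        trainer.enable_validation
        and sum(trainer.num_val_batches) > 0
        and (trainer.current_epoch + 1) % check_every == 0
    )
```

For `train_time_interval`, the time-trigger decision is computed once and shared across ranks with `trainer.strategy.broadcast`, so all processes agree on whether a save is due.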

Tests

  • Repro (previously failing, now passing):
    • [tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py]
  • Additional validations:
    • [tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py]
    • [tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py]

Outcome

  • `best_model_score` matches the latest validation metric.
  • Step/time-interval checkpointing behaves correctly without premature or skipped saves.
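
Continuing the repro sketch from the root-cause section, a quick illustrative check of this outcome (these assertions are mine, not part of the PR's test suite):

```python
# After the fix, the deferred save runs in on_validation_end, so the recorded
# best score is a value the validation loop actually logged.
assert ckpt.best_model_score is not None
assert "val_loss" in trainer.callback_metrics
print(ckpt.best_model_path, ckpt.best_model_score)
```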

📚 Documentation preview 📚: https://pytorch-lightning--21106.org.readthedocs.build/en/21106/

github-actions bot added the labels pl (Generic label for PyTorch Lightning package) and fabric (lightning.fabric.Fabric) on Aug 21, 2025
codecov bot commented Aug 22, 2025

Codecov Report

❌ Patch coverage is 90.32258% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 87%. Comparing base (e55650d) to head (6c1554a).
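
(For scale: 3 uncovered lines at 90.32258% patch coverage means 31 changed lines are tracked, since 28/31 ≈ 90.32%.)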

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21106   +/-   ##
=======================================
  Coverage      87%      87%           
=======================================
  Files         269      269           
  Lines       23520    23545   +25     
=======================================
+ Hits        20508    20542   +34     
+ Misses       3012     3003    -9     

littlebullGit (Contributor, Author) commented

@Borda, I cannot see the error in the two failed jobs. Can you help me or point me to what the error is?

Borda (Contributor) commented Aug 22, 2025

> I cannot see the error in the two failed jobs. Can you help me or point me to what the error is?

These tests are optional for now, as the same ones are also failing on master.

… validation metrics are available

Root cause:
- With `every_n_train_steps` (or `train_time_interval`), checkpoints could save at train batch end before validation ran, so the monitored val metric was missing/stale and `best_model_score` was incorrect. (Refs Lightning-AI#20919)

Fix:
- In [src/lightning/pytorch/callbacks/model_checkpoint.py:ModelCheckpoint.on_train_batch_end]:
  - Defer saves when the monitored key is missing from [trainer.callback_metrics]
  - If on the last train batch and not saving at train-epoch-end, defer only when validation will run next:
    - `trainer.enable_validation` is True
    - `trainer.num_val_batches` > 0
    - `trainer.check_val_every_n_epoch` schedule matches the upcoming epoch
- Perform deferred saves in [on_validation_end], ensuring fresh validation metrics are used.
- Allow zero `timedelta` for `train_time_interval` and broadcast the time-trigger decision across ranks.
- Do not defer when monitoring a train metric or when no validation is scheduled.

Tests:
- Repro (previously failing, now passing):
  - [tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py]
- Additional validations:
  - [tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py]
  - [tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py]

Outcome:
- `best_model_score` matches the validation metric after the epoch.
- Step/time-interval checkpointing behaves correctly without premature or skipped saves.
littlebullGit force-pushed the fix/20919-checkpoint-step-val-metric branch from 094b278 to 6c1554a on August 27, 2025 22:10
Borda merged commit b1cc925 into Lightning-AI:master on Aug 29, 2025
88 of 91 checks passed
Borda added a commit that referenced this pull request Sep 3, 2025
… validation metrics are available (#21106)

* fix(callbacks): defer step/time-triggered ModelCheckpoint saves until validation metrics are available

* test: disable logger in model checkpoint tests to avoid side effects

* chlog

---------

Co-authored-by: Jirka B <[email protected]>
(cherry picked from commit b1cc925)
lantiga pushed a commit that referenced this pull request Sep 5, 2025
… validation metrics are available (#21106)

* fix(callbacks): defer step/time-triggered ModelCheckpoint saves until validation metrics are available

* test: disable logger in model checkpoint tests to avoid side effects

* chlog

---------

Co-authored-by: Jirka B <[email protected]>
(cherry picked from commit b1cc925)