From d094b1c18814f4e2d0dc4ae4845711afbb893eb4 Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 21 Aug 2025 18:02:18 -0400 Subject: [PATCH 1/7] fix(callbacks): defer step/time-triggered ModelCheckpoint saves until validation metrics are available Root cause: - With `every_n_train_steps` (or `train_time_interval`), checkpoints could save at train batch end before validation ran, so the monitored val metric was missing/stale and `best_model_score` was incorrect. (Refs #20919) Fix: - In [src/lightning/pytorch/callbacks/model_checkpoint.py:ModelCheckpoint.on_train_batch_end]: - Defer saves when the monitored key is missing from [trainer.callback_metrics] - If on the last train batch and not saving at train-epoch-end, defer only when validation will run next: - `trainer.enable_validation` is True - `trainer.num_val_batches` > 0 - `trainer.check_val_every_n_epoch` schedule matches the upcoming epoch - Perform deferred saves in [on_validation_end], ensuring fresh validation metrics are used. - Allow zero `timedelta` for `train_time_interval` and broadcast the time-trigger decision across ranks. - Do not defer when monitoring a train metric or when no validation is scheduled. Tests: - Repro (previously failing, now passing): - [tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py] - Additional validations: - [tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py] - [tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py] Outcome: - `best_model_score` matches the validation metric after the epoch. - Step/time-interval checkpointing behaves correctly without premature or skipped saves. --- .../pytorch/callbacks/model_checkpoint.py | 56 ++++- .../test_model_checkpoint_additional_cases.py | 205 ++++++++++++++++++ .../test_model_checkpoint_edge_cases.py | 172 +++++++++++++++ ...del_checkpoint_step_interval_val_metric.py | 105 +++++++++ 4 files changed, 535 insertions(+), 3 deletions(-) create mode 100644 tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py create mode 100644 tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py create mode 100644 tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 68fed2ff82d31..452e8bdecbba3 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -260,6 +260,9 @@ def __init__( self.best_model_path = "" self.last_model_path = "" self._last_checkpoint_saved = "" + # When using step/time-based checkpointing with a validation-only monitored metric, + # defer the save until validation has produced the metric + self._defer_save_until_validation: bool = False self.kth_value: Tensor self.dirpath: Optional[_PATH] @@ -306,14 +309,17 @@ def on_train_batch_end( batch_idx: int, ) -> None: """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" - if self._should_skip_saving_checkpoint(trainer): - return + # Do not return early here because we may need to set deferral flags even + # if a save already happened at this global step. We'll enforce the skip + # just before actually saving below. 
+ skip_due_to_state = self._should_skip_saving_checkpoint(trainer) skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0) train_time_interval = self._train_time_interval skip_time = True now = time.monotonic() - if train_time_interval: + # Important: allow zero timedelta as a valid interval + if train_time_interval is not None: prev_time_check = self._last_time_checked skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() # in case we have time differences across ranks @@ -326,6 +332,42 @@ def on_train_batch_end( self._last_time_checked = now monitor_candidates = self._monitor_candidates(trainer) + # If monitoring a metric that is not yet available (e.g., validation-only), + # defer saving until validation end so the metric is present. + if self.monitor is not None and self.monitor not in monitor_candidates: + # Defer both top-k and last to avoid blocking with `_last_global_step_saved` + self._defer_save_until_validation = True + return + + # Even if the monitored key exists, it could be stale from a previous validation. + # If validation is scheduled to run right after this batch (e.g., last batch of epoch) + # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics. + if ( + self.monitor is not None + and not self._should_save_on_train_epoch_end(trainer) + and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False) + ): + # Only defer if a validation loop is expected to run after this batch. + will_run_val = False + if getattr(trainer, "enable_validation", False): + num_val_batches = ( + sum(trainer.num_val_batches) + if isinstance(trainer.num_val_batches, list) + else trainer.num_val_batches + ) + if num_val_batches and num_val_batches > 0: + cve = trainer.check_val_every_n_epoch + if cve is None or ((trainer.current_epoch + 1) % cve == 0): + will_run_val = True + + if will_run_val: + self._defer_save_until_validation = True + return + + # Only proceed to save if not skipping due to trainer/callback state + if skip_due_to_state: + return + self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) @@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul """Save a checkpoint at the end of the validation stage.""" if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) + # If a step/time-triggered save was deferred due to a missing monitored metric, + # perform the save now that validation metrics are available. 
+ if self._defer_save_until_validation: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + self._defer_save_until_validation = False + return + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py new file mode 100644 index 0000000000000..712575bec1a26 --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py @@ -0,0 +1,205 @@ +import math +from datetime import timedelta + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 4): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +class TrainMetricModule(LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(1, 1) + self._counter = 0.0 + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + # strictly increasing train metric per step + self._counter += 1.0 + self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +def _make_loaders(n=4): + ds = TinyDataset(n=n) + train_loader = DataLoader(ds, batch_size=2, shuffle=False) + val_loader = DataLoader(ds, batch_size=2, shuffle=False) + return train_loader, val_loader + + +def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path): + """When monitoring a train-step metric, step-interval checkpointing should save at the step boundary (no deferral) + and best_model_score should match the last train metric value.""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + model = TrainMetricModule() + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="train_score", + mode="max", + save_top_k=1, + every_n_train_steps=1, + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + # 2 batches/epoch, run 2 epochs to have multiple step saves + trainer = Trainer( + max_epochs=2, + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=0, # no validation needed for this test + enable_checkpointing=True, + enable_model_summary=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + # 2 epochs * 2 steps/epoch = 4 steps total; metric increments by 1 each step + expected = 4.0 + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]]) +def 
test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores): + """With time-interval-based checkpointing, and a validation-only metric, ensure we don't save using stale metrics + at step boundaries; saving should occur at validation end.""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=0, # disable step-based + train_time_interval=timedelta(seconds=0), # trigger as often as possible + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +class ValMetricModule(LightningModule): + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + score = self._val_scores[self.current_epoch] + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]]) +def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores): + """With validation running every 2 epochs, step-triggered saves at the end of non-validation epochs should be + deferred and then performed at the next validation end when the metric is available.""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=2, # end of each epoch + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + check_val_every_n_epoch=2, # only validate every 2 epochs + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) # last/maximum value occurs at final validation epoch + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py new file mode 100644 index 0000000000000..d294757d1ef76 --- /dev/null +++ 
b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py @@ -0,0 +1,172 @@ +import math +from datetime import timedelta + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 8): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +def _make_loaders(n=8, batch_size=2): + ds = TinyDataset(n=n) + train_loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + val_loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + return train_loader, val_loader + + +class MultiValPerEpochModule(LightningModule): + """Logs a validation metric on every validation run, even if validation is run multiple times per epoch.""" + + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + self._val_call_idx = 0 + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + score = self._val_scores[self._val_call_idx] + self._val_call_idx += 1 + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +class ValOnceEveryTwoEpochsModule(LightningModule): + """Logs a validation metric only when validation runs (e.g., every 2 epochs), indexed by current_epoch.""" + + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + # current_epoch indexes into provided scores; only called when validation runs + score = self._val_scores[self.current_epoch] + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.9]]) +def test_checkpoint_defers_with_mid_epoch_validation(tmp_path, val_scores): + """With val_check_interval=0.5 (validation mid-epoch and at epoch end), and step-based checkpointing, saves must be + deferred until each validation end so monitored validation metrics are fresh.""" + seed_everything(123) + + # 4 train batches per epoch (batch_size=2 over n=8), so two validations: after 2 batches and after 4 batches + train_loader, val_loader = _make_loaders(n=8, batch_size=2) + + model = MultiValPerEpochModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=1, # would trigger every step, but must defer to validation + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=1, + accelerator="cpu", + devices=1, + 
callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=4, # ensure exactly 4 steps => two validations at 0.5 and 1.0 + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + val_check_interval=0.5, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +@pytest.mark.parametrize("val_scores", [[0.2, 0.6]]) +def test_time_interval_defers_across_epoch_until_first_validation(tmp_path, val_scores): + """With time-interval saving and validation only every 2 epochs, ensure no save uses stale/missing validation + metrics; the first save should happen at the first validation end (epoch 2).""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4, batch_size=2) + + model = ValOnceEveryTwoEpochsModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=0, # disable step-based + train_time_interval=timedelta(seconds=0), # trigger frequently + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=2, + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + check_val_every_n_epoch=2, # first validation only after 2nd epoch + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = val_scores[1] # validation runs only once at epoch 2, logging index 1 + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py new file mode 100644 index 0000000000000..2922a7597810b --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py @@ -0,0 +1,105 @@ +import math + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 4): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +class ValMetricModule(LightningModule): + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + # LightningModule API (minimal) + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + # do nothing per-step; we log at epoch end + pass + + def on_validation_epoch_end(self): + # Log a validation metric only at validation epoch end + # Values increase across epochs; best should be the last epoch + score = self._val_scores[self.current_epoch] + # use logger=True so it lands in 
trainer.callback_metrics + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0]]) +def test_model_checkpoint_every_n_train_steps_with_val_metric_saves_after_val(tmp_path, val_scores): + """Reproduces #20919: Using every_n_train_steps with a validation-only metric should save the best checkpoint only + after the metric is computed at validation, not earlier at the train-step boundary. + + Expectation: best_model_score equals the last (max) val score. + + """ + seed_everything(123) + + # 2 train batches per epoch (so checkpoint triggers at the epoch boundary) + ds = TinyDataset(n=4) + train_loader = DataLoader(ds, batch_size=2, shuffle=False) + val_loader = DataLoader(ds, batch_size=2, shuffle=False) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + # critical: trigger on train steps, not on epoch end + every_n_train_steps=2, # equal to number of train batches per epoch + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + # Should equal the last (max) validation score + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6), ( + f"best_model_score should be {expected} (last/maximum val score), got {actual}.\n" + f"This indicates the checkpoint was saved before the validation metric was computed." 
+ ) From b88b546afdcb6f02526e5ba472bee8034642e6e9 Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 21 Aug 2025 19:14:04 -0400 Subject: [PATCH 2/7] test: disable logger in model checkpoint tests to avoid side effects --- .../callbacks/test_model_checkpoint_additional_cases.py | 3 +++ .../callbacks/test_model_checkpoint_edge_cases.py | 2 ++ .../test_model_checkpoint_step_interval_val_metric.py | 1 + 3 files changed, 6 insertions(+) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py index 712575bec1a26..fba9e865debfd 100644 --- a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py @@ -84,6 +84,7 @@ def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tm limit_val_batches=0, # no validation needed for this test enable_checkpointing=True, enable_model_summary=False, + logger=False, ) trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) @@ -128,6 +129,7 @@ def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation( limit_val_batches=1, enable_checkpointing=True, enable_model_summary=False, + logger=False, ) trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) @@ -194,6 +196,7 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm limit_val_batches=1, enable_checkpointing=True, enable_model_summary=False, + logger=False, check_val_every_n_epoch=2, # only validate every 2 epochs ) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py index d294757d1ef76..a265f8bc5f194 100644 --- a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py @@ -117,6 +117,7 @@ def test_checkpoint_defers_with_mid_epoch_validation(tmp_path, val_scores): limit_val_batches=1, enable_checkpointing=True, enable_model_summary=False, + logger=False, val_check_interval=0.5, ) @@ -161,6 +162,7 @@ def test_time_interval_defers_across_epoch_until_first_validation(tmp_path, val_ limit_val_batches=1, enable_checkpointing=True, enable_model_summary=False, + logger=False, check_val_every_n_epoch=2, # first validation only after 2nd epoch ) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py index 2922a7597810b..c3fa0bfcd2e38 100644 --- a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py @@ -91,6 +91,7 @@ def test_model_checkpoint_every_n_train_steps_with_val_metric_saves_after_val(tm limit_val_batches=1, enable_checkpointing=True, enable_model_summary=False, + logger=False, ) trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) From 59dda0258f97ea06ccd343a6a1a665d4ae2d147e Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 21 Aug 2025 21:59:04 -0400 Subject: [PATCH 3/7] refactor: defer DeepSpeed import and logging configuration until needed --- src/lightning/fabric/strategies/deepspeed.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py 
index c11ae8589d1ff..13013a3a666ec 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -282,10 +282,10 @@ def __init__( sub_group_size=sub_group_size, ) - import deepspeed - self._config_initialized = False - deepspeed.utils.logging.logger.setLevel(logging_level) + # Defer importing and configuring DeepSpeed logging until it is actually needed. + # Store the desired logging level to be applied on first use. + self._logging_level = logging_level self.remote_device = remote_device self.load_full_weights = load_full_weights @@ -374,6 +374,8 @@ def module_sharded_context(self) -> AbstractContextManager: import deepspeed + deepspeed.utils.logging.logger.setLevel(self._logging_level) + assert self._config_initialized return deepspeed.zero.Init( enabled=self.zero_stage_3, @@ -601,6 +603,8 @@ def _initialize_engine( """ import deepspeed + deepspeed.utils.logging.logger.setLevel(self._logging_level) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize( args=argparse.Namespace(device_rank=self.root_device.index), @@ -628,7 +632,11 @@ def _setup_distributed(self) -> None: _validate_device_index_selection(self.parallel_devices) reset_seed() self._set_world_ranks() - self._init_deepspeed_distributed() + # Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing + # DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc), + # while still allowing configuration and dataloader setup logic to run. + if self.world_size > 1: + self._init_deepspeed_distributed() if not self._config_initialized: self._format_config() self._config_initialized = True @@ -636,6 +644,8 @@ def _setup_distributed(self) -> None: def _init_deepspeed_distributed(self) -> None: import deepspeed + deepspeed.utils.logging.logger.setLevel(self._logging_level) + assert self.cluster_environment is not None if platform.system() != "Windows": # do not set env variables on windows, allow deepspeed to control setup @@ -661,6 +671,8 @@ def _set_node_environment_variables(self) -> None: def _set_deepspeed_activation_checkpointing(self) -> None: import deepspeed + deepspeed.utils.logging.logger.setLevel(self._logging_level) + assert isinstance(self.config, dict) if self.config.get("activation_checkpointing"): checkpoint_config = self.config["activation_checkpointing"] From 6c1554a631dfadf2274caae0e5755e6ef06cf11c Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 21 Aug 2025 22:48:01 -0400 Subject: [PATCH 4/7] test: add mock-based CPU tests for DeepSpeed strategy import paths --- .../strategies/test_deepspeed_imports_mock.py | 229 ++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/tests_fabric/strategies/test_deepspeed_imports_mock.py diff --git a/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py b/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py new file mode 100644 index 0000000000000..9d9dd4d3f0873 --- /dev/null +++ b/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py @@ -0,0 +1,229 @@ +# Copyright The Lightning AI team. +# This test file provides CPU-only coverage for DeepSpeed lazy-import paths by mocking a minimal +# `deepspeed` module. It does not require GPUs or the real DeepSpeed package. 
+ +import sys +from types import ModuleType +from unittest.mock import Mock + +import pytest + +from lightning.fabric.strategies import DeepSpeedStrategy + + +class _FakeLogger: + def __init__(self): + self.levels = [] + + def setLevel(self, lvl): + self.levels.append(lvl) + + +class _FakeZeroInit: + def __init__(self, *args, **kwargs): + # record for assertions + self.args = args + self.kwargs = kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +@pytest.fixture +def fake_deepspeed(monkeypatch): + """Inject a minimal fake `deepspeed` package into sys.modules.""" + ds = ModuleType("deepspeed") + # Mark as a package with a spec and path so importlib won't complain + import importlib.machinery + + ds.__spec__ = importlib.machinery.ModuleSpec("deepspeed", loader=Mock(), is_package=True) + ds.__path__ = [] # type: ignore[attr-defined] + + # utils.logging.logger + utils_mod = ModuleType("deepspeed.utils") + logging_mod = ModuleType("deepspeed.utils.logging") + utils_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils", loader=Mock(), is_package=True) + logging_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils.logging", loader=Mock(), is_package=False) + logger = _FakeLogger() + logging_mod.logger = logger + utils_mod.logging = logging_mod + ds.utils = utils_mod + + # zero.Init + zero_mod = ModuleType("deepspeed.zero") + zero_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.zero", loader=Mock(), is_package=False) + zero_mod.Init = _FakeZeroInit + ds.zero = zero_mod + + # checkpointing.configure + checkpointing_mod = ModuleType("deepspeed.checkpointing") + checkpointing_mod.__spec__ = importlib.machinery.ModuleSpec( + "deepspeed.checkpointing", loader=Mock(), is_package=False + ) + recorded = {"configure_calls": []} + + def _configure(**kwargs): + recorded["configure_calls"].append(kwargs) + + checkpointing_mod.configure = _configure + ds.checkpointing = checkpointing_mod + + # initialize + recorded["initialize_calls"] = [] + + def _initialize(**kwargs): + recorded["initialize_calls"].append(kwargs) + # return values: (engine, optimizer, _, scheduler) + return Mock(name="engine"), Mock(name="optimizer"), None, Mock(name="scheduler") + + ds.initialize = _initialize + + # init_distributed + recorded["init_distributed_calls"] = [] + + def _init_distributed(*args, **kwargs): + recorded["init_distributed_calls"].append((args, kwargs)) + + ds.init_distributed = _init_distributed + + # install into sys.modules + monkeypatch.setitem(sys.modules, "deepspeed", ds) + monkeypatch.setitem(sys.modules, "deepspeed.utils", utils_mod) + monkeypatch.setitem(sys.modules, "deepspeed.utils.logging", logging_mod) + monkeypatch.setitem(sys.modules, "deepspeed.zero", zero_mod) + monkeypatch.setitem(sys.modules, "deepspeed.checkpointing", checkpointing_mod) + + # Pretend deepspeed is installed by forcing availability flag to True + monkeypatch.setattr("lightning.fabric.strategies.deepspeed._DEEPSPEED_AVAILABLE", True, raising=False) + + return ds, logger, recorded + + +def _make_strategy_with_defaults(): + # Use defaults; we'll tweak attributes per test as needed + return DeepSpeedStrategy() + + +def _get_backend() -> str: + # simple helper used to override strategy._get_process_group_backend + return "gloo" + + +def test_module_sharded_context_sets_logger_and_returns_zero_init(fake_deepspeed): + ds_mod, logger, recorded = fake_deepspeed + + strategy = _make_strategy_with_defaults() + # The context asserts that the config was 
initialized + strategy._config_initialized = True # type: ignore[attr-defined] + + ctx = strategy.module_sharded_context() + assert isinstance(ctx, _FakeZeroInit) + # logger.setLevel should be called at least once + assert len(logger.levels) >= 1 + + +def test_initialize_engine_import_and_logger_and_call(fake_deepspeed): + ds_mod, logger, recorded = fake_deepspeed + + strategy = _make_strategy_with_defaults() + # root_device.index is read; use a CUDA device number even on CPU-only hosts (no allocation happens) + import torch + + strategy.parallel_devices = [torch.device("cuda", 0)] # type: ignore[attr-defined] + + class _Param: + requires_grad = True + + model = Mock() + model.parameters.return_value = [_Param()] + + engine, optimizer, scheduler = strategy._initialize_engine(model) + + # assertions + assert len(logger.levels) >= 1 + assert recorded["initialize_calls"], "deepspeed.initialize was not called" + call = recorded["initialize_calls"][0] + assert call["config"] == strategy.config + assert call["model"] is model + assert call["dist_init_required"] is False + # returned mocks are propagated + from unittest.mock import Mock as _M + + assert isinstance(engine, _M) + assert engine._mock_name == "engine" + assert isinstance(optimizer, _M) + assert optimizer._mock_name == "optimizer" + assert isinstance(scheduler, _M) + assert scheduler._mock_name == "scheduler" + + +def test_init_deepspeed_distributed_calls_import_and_init(fake_deepspeed, monkeypatch): + ds_mod, logger, recorded = fake_deepspeed + + strategy = _make_strategy_with_defaults() + + # minimal cluster env + class _CE: + main_port = 12345 + main_address = "127.0.0.1" + + def global_rank(self): + return 0 + + def local_rank(self): + return 0 + + def node_rank(self): + return 0 + + def world_size(self): + return 1 + + def teardown(self): + pass + + strategy.cluster_environment = _CE() + strategy._process_group_backend = "gloo" # avoid CUDA requirement + strategy._timeout = 300 # type: ignore[attr-defined] + + strategy._get_process_group_backend = _get_backend # type: ignore[assignment] + + # ensure non-Windows path + monkeypatch.setattr("platform.system", lambda: "Linux") + + strategy._init_deepspeed_distributed() + + assert len(logger.levels) >= 1 + assert recorded["init_distributed_calls"], "deepspeed.init_distributed was not called" + args, kwargs = recorded["init_distributed_calls"][0] + assert args[0] == "gloo" + assert kwargs["distributed_port"] == 12345 + assert "timeout" in kwargs + + +def test_set_deepspeed_activation_checkpointing_configured(fake_deepspeed): + ds_mod, logger, recorded = fake_deepspeed + + strategy = _make_strategy_with_defaults() + # ensure config contains activation_checkpointing keys + assert isinstance(strategy.config, dict) + strategy.config.setdefault("activation_checkpointing", {}) + strategy.config["activation_checkpointing"].update({ + "partition_activations": True, + "contiguous_memory_optimization": False, + "cpu_checkpointing": True, + "profile": False, + }) + + strategy._set_deepspeed_activation_checkpointing() + + assert len(logger.levels) >= 1 + assert recorded["configure_calls"], "deepspeed.checkpointing.configure was not called" + cfg = recorded["configure_calls"][0] + assert cfg["partition_activations"] is True + assert cfg["contiguous_checkpointing"] is False + assert cfg["checkpoint_in_cpu"] is True + assert cfg["profile"] is False From ef816b62a67ff3ab656c8edf80cc4bbede903625 Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 28 Aug 2025 19:30:00 -0400 Subject: [PATCH 5/7] 
Revert "refactor: defer DeepSpeed import and logging configuration until needed" This reverts commit 59dda0258f97ea06ccd343a6a1a665d4ae2d147e. --- src/lightning/fabric/strategies/deepspeed.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 13013a3a666ec..c11ae8589d1ff 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -282,10 +282,10 @@ def __init__( sub_group_size=sub_group_size, ) + import deepspeed + self._config_initialized = False - # Defer importing and configuring DeepSpeed logging until it is actually needed. - # Store the desired logging level to be applied on first use. - self._logging_level = logging_level + deepspeed.utils.logging.logger.setLevel(logging_level) self.remote_device = remote_device self.load_full_weights = load_full_weights @@ -374,8 +374,6 @@ def module_sharded_context(self) -> AbstractContextManager: import deepspeed - deepspeed.utils.logging.logger.setLevel(self._logging_level) - assert self._config_initialized return deepspeed.zero.Init( enabled=self.zero_stage_3, @@ -603,8 +601,6 @@ def _initialize_engine( """ import deepspeed - deepspeed.utils.logging.logger.setLevel(self._logging_level) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize( args=argparse.Namespace(device_rank=self.root_device.index), @@ -632,11 +628,7 @@ def _setup_distributed(self) -> None: _validate_device_index_selection(self.parallel_devices) reset_seed() self._set_world_ranks() - # Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing - # DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc), - # while still allowing configuration and dataloader setup logic to run. - if self.world_size > 1: - self._init_deepspeed_distributed() + self._init_deepspeed_distributed() if not self._config_initialized: self._format_config() self._config_initialized = True @@ -644,8 +636,6 @@ def _setup_distributed(self) -> None: def _init_deepspeed_distributed(self) -> None: import deepspeed - deepspeed.utils.logging.logger.setLevel(self._logging_level) - assert self.cluster_environment is not None if platform.system() != "Windows": # do not set env variables on windows, allow deepspeed to control setup @@ -671,8 +661,6 @@ def _set_node_environment_variables(self) -> None: def _set_deepspeed_activation_checkpointing(self) -> None: import deepspeed - deepspeed.utils.logging.logger.setLevel(self._logging_level) - assert isinstance(self.config, dict) if self.config.get("activation_checkpointing"): checkpoint_config = self.config["activation_checkpointing"] From 836de5a751b9e5f67b389edea2d5b8dee663d4d1 Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Thu, 28 Aug 2025 19:30:00 -0400 Subject: [PATCH 6/7] Revert "test: add mock-based CPU tests for DeepSpeed strategy import paths" This reverts commit 6c1554a631dfadf2274caae0e5755e6ef06cf11c. 
--- .../strategies/test_deepspeed_imports_mock.py | 229 ------------------ 1 file changed, 229 deletions(-) delete mode 100644 tests/tests_fabric/strategies/test_deepspeed_imports_mock.py diff --git a/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py b/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py deleted file mode 100644 index 9d9dd4d3f0873..0000000000000 --- a/tests/tests_fabric/strategies/test_deepspeed_imports_mock.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright The Lightning AI team. -# This test file provides CPU-only coverage for DeepSpeed lazy-import paths by mocking a minimal -# `deepspeed` module. It does not require GPUs or the real DeepSpeed package. - -import sys -from types import ModuleType -from unittest.mock import Mock - -import pytest - -from lightning.fabric.strategies import DeepSpeedStrategy - - -class _FakeLogger: - def __init__(self): - self.levels = [] - - def setLevel(self, lvl): - self.levels.append(lvl) - - -class _FakeZeroInit: - def __init__(self, *args, **kwargs): - # record for assertions - self.args = args - self.kwargs = kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -@pytest.fixture -def fake_deepspeed(monkeypatch): - """Inject a minimal fake `deepspeed` package into sys.modules.""" - ds = ModuleType("deepspeed") - # Mark as a package with a spec and path so importlib won't complain - import importlib.machinery - - ds.__spec__ = importlib.machinery.ModuleSpec("deepspeed", loader=Mock(), is_package=True) - ds.__path__ = [] # type: ignore[attr-defined] - - # utils.logging.logger - utils_mod = ModuleType("deepspeed.utils") - logging_mod = ModuleType("deepspeed.utils.logging") - utils_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils", loader=Mock(), is_package=True) - logging_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils.logging", loader=Mock(), is_package=False) - logger = _FakeLogger() - logging_mod.logger = logger - utils_mod.logging = logging_mod - ds.utils = utils_mod - - # zero.Init - zero_mod = ModuleType("deepspeed.zero") - zero_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.zero", loader=Mock(), is_package=False) - zero_mod.Init = _FakeZeroInit - ds.zero = zero_mod - - # checkpointing.configure - checkpointing_mod = ModuleType("deepspeed.checkpointing") - checkpointing_mod.__spec__ = importlib.machinery.ModuleSpec( - "deepspeed.checkpointing", loader=Mock(), is_package=False - ) - recorded = {"configure_calls": []} - - def _configure(**kwargs): - recorded["configure_calls"].append(kwargs) - - checkpointing_mod.configure = _configure - ds.checkpointing = checkpointing_mod - - # initialize - recorded["initialize_calls"] = [] - - def _initialize(**kwargs): - recorded["initialize_calls"].append(kwargs) - # return values: (engine, optimizer, _, scheduler) - return Mock(name="engine"), Mock(name="optimizer"), None, Mock(name="scheduler") - - ds.initialize = _initialize - - # init_distributed - recorded["init_distributed_calls"] = [] - - def _init_distributed(*args, **kwargs): - recorded["init_distributed_calls"].append((args, kwargs)) - - ds.init_distributed = _init_distributed - - # install into sys.modules - monkeypatch.setitem(sys.modules, "deepspeed", ds) - monkeypatch.setitem(sys.modules, "deepspeed.utils", utils_mod) - monkeypatch.setitem(sys.modules, "deepspeed.utils.logging", logging_mod) - monkeypatch.setitem(sys.modules, "deepspeed.zero", zero_mod) - monkeypatch.setitem(sys.modules, "deepspeed.checkpointing", 
checkpointing_mod) - - # Pretend deepspeed is installed by forcing availability flag to True - monkeypatch.setattr("lightning.fabric.strategies.deepspeed._DEEPSPEED_AVAILABLE", True, raising=False) - - return ds, logger, recorded - - -def _make_strategy_with_defaults(): - # Use defaults; we'll tweak attributes per test as needed - return DeepSpeedStrategy() - - -def _get_backend() -> str: - # simple helper used to override strategy._get_process_group_backend - return "gloo" - - -def test_module_sharded_context_sets_logger_and_returns_zero_init(fake_deepspeed): - ds_mod, logger, recorded = fake_deepspeed - - strategy = _make_strategy_with_defaults() - # The context asserts that the config was initialized - strategy._config_initialized = True # type: ignore[attr-defined] - - ctx = strategy.module_sharded_context() - assert isinstance(ctx, _FakeZeroInit) - # logger.setLevel should be called at least once - assert len(logger.levels) >= 1 - - -def test_initialize_engine_import_and_logger_and_call(fake_deepspeed): - ds_mod, logger, recorded = fake_deepspeed - - strategy = _make_strategy_with_defaults() - # root_device.index is read; use a CUDA device number even on CPU-only hosts (no allocation happens) - import torch - - strategy.parallel_devices = [torch.device("cuda", 0)] # type: ignore[attr-defined] - - class _Param: - requires_grad = True - - model = Mock() - model.parameters.return_value = [_Param()] - - engine, optimizer, scheduler = strategy._initialize_engine(model) - - # assertions - assert len(logger.levels) >= 1 - assert recorded["initialize_calls"], "deepspeed.initialize was not called" - call = recorded["initialize_calls"][0] - assert call["config"] == strategy.config - assert call["model"] is model - assert call["dist_init_required"] is False - # returned mocks are propagated - from unittest.mock import Mock as _M - - assert isinstance(engine, _M) - assert engine._mock_name == "engine" - assert isinstance(optimizer, _M) - assert optimizer._mock_name == "optimizer" - assert isinstance(scheduler, _M) - assert scheduler._mock_name == "scheduler" - - -def test_init_deepspeed_distributed_calls_import_and_init(fake_deepspeed, monkeypatch): - ds_mod, logger, recorded = fake_deepspeed - - strategy = _make_strategy_with_defaults() - - # minimal cluster env - class _CE: - main_port = 12345 - main_address = "127.0.0.1" - - def global_rank(self): - return 0 - - def local_rank(self): - return 0 - - def node_rank(self): - return 0 - - def world_size(self): - return 1 - - def teardown(self): - pass - - strategy.cluster_environment = _CE() - strategy._process_group_backend = "gloo" # avoid CUDA requirement - strategy._timeout = 300 # type: ignore[attr-defined] - - strategy._get_process_group_backend = _get_backend # type: ignore[assignment] - - # ensure non-Windows path - monkeypatch.setattr("platform.system", lambda: "Linux") - - strategy._init_deepspeed_distributed() - - assert len(logger.levels) >= 1 - assert recorded["init_distributed_calls"], "deepspeed.init_distributed was not called" - args, kwargs = recorded["init_distributed_calls"][0] - assert args[0] == "gloo" - assert kwargs["distributed_port"] == 12345 - assert "timeout" in kwargs - - -def test_set_deepspeed_activation_checkpointing_configured(fake_deepspeed): - ds_mod, logger, recorded = fake_deepspeed - - strategy = _make_strategy_with_defaults() - # ensure config contains activation_checkpointing keys - assert isinstance(strategy.config, dict) - strategy.config.setdefault("activation_checkpointing", {}) - 
strategy.config["activation_checkpointing"].update({ - "partition_activations": True, - "contiguous_memory_optimization": False, - "cpu_checkpointing": True, - "profile": False, - }) - - strategy._set_deepspeed_activation_checkpointing() - - assert len(logger.levels) >= 1 - assert recorded["configure_calls"], "deepspeed.checkpointing.configure was not called" - cfg = recorded["configure_calls"][0] - assert cfg["partition_activations"] is True - assert cfg["contiguous_checkpointing"] is False - assert cfg["checkpoint_in_cpu"] is True - assert cfg["profile"] is False From a2a5964c201158ff5e7ead4d17a4a51977fc49ca Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 29 Aug 2025 18:30:32 +0200 Subject: [PATCH 7/7] chlog --- src/lightning/pytorch/CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fe9173d008230..176e34273d776 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106)) + ---