@@ -380,7 +380,10 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
380
380
monitor_candidates = self ._monitor_candidates (trainer )
381
381
if self ._every_n_epochs >= 1 and (trainer .current_epoch + 1 ) % self ._every_n_epochs == 0 :
382
382
self ._save_topk_checkpoint (trainer , monitor_candidates )
383
- self ._save_last_checkpoint (trainer , monitor_candidates )
383
+ # Only save the last checkpoint if a checkpoint was saved at this global step, or if save_last="link" and `_last_checkpoint_saved` is set
384
+ if (self ._last_global_step_saved == trainer .global_step or
385
+ (self .save_last == "link" and self ._last_checkpoint_saved )):
386
+ self ._save_last_checkpoint (trainer , monitor_candidates )
384
387
385
388
@override
386
389
def on_validation_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
@@ -397,7 +400,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
397
400
398
401
if self ._every_n_epochs >= 1 and (trainer .current_epoch + 1 ) % self ._every_n_epochs == 0 :
399
402
self ._save_topk_checkpoint (trainer , monitor_candidates )
400
- self ._save_last_checkpoint (trainer , monitor_candidates )
403
+ # Only save the last checkpoint if a checkpoint was saved at this global step, or if save_last="link" and `_last_checkpoint_saved` is set
404
+ if (self ._last_global_step_saved == trainer .global_step or
405
+ (self .save_last == "link" and self ._last_checkpoint_saved )):
406
+ self ._save_last_checkpoint (trainer , monitor_candidates )
401
407
402
408
@override
403
409
def on_exception (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , exception : BaseException ) -> None :
@@ -902,3 +908,5 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
902
908
def _remove_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
903
909
"""Calls the strategy to remove the checkpoint file."""
904
910
trainer .strategy .remove_checkpoint (filepath )
911
+
912
+
0 commit comments