@@ -2124,3 +2124,57 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
2124
2124
2125
2125
# save_last=True should always save last.ckpt
2126
2126
assert (tmp_path / "last.ckpt" ).exists ()
2127
+
2128
+
2129
def test_save_last_only_when_checkpoint_saved(tmp_path):
    """Test that save_last only creates last.ckpt when another checkpoint is actually saved.

    A model with a scripted, non-monotonic ``val_loss`` is trained for four epochs with
    ``save_top_k=1`` and ``save_last=True``. After training, the checkpoint directory is
    expected to hold exactly two files: the single retained ``best-*`` checkpoint and
    ``last.ckpt``.

    Args:
        tmp_path: pytest-provided temporary directory used as the checkpoint directory
            and as the trainer's root directory.

    """

    class SelectiveModel(BoringModel):
        """BoringModel variant that logs a hand-crafted ``val_loss`` per epoch."""

        def __init__(self):
            super().__init__()
            # Per-batch outputs of the current validation epoch; cleared at epoch end.
            self.validation_step_outputs = []

        def validation_step(self, batch, batch_idx):
            outputs = super().validation_step(batch, batch_idx)
            epoch = self.trainer.current_epoch
            # Even epochs get a decreasing loss (1.0, 0.8, ...), odd epochs an increasing
            # one (1.1, 1.3, ...), so the monitored metric only improves on some epochs.
            if epoch % 2 == 0:
                loss = torch.tensor(1.0 - epoch * 0.1)
            else:
                loss = torch.tensor(1.0 + epoch * 0.1)
            outputs["val_loss"] = loss
            self.validation_step_outputs.append(outputs)
            return outputs

        def on_validation_epoch_end(self):
            # Nothing to aggregate if no validation batch ran in this epoch.
            if not self.validation_step_outputs:
                return
            step_losses = [step_output["val_loss"] for step_output in self.validation_step_outputs]
            self.log("val_loss", torch.stack(step_losses).mean())
            self.validation_step_outputs.clear()

    model = SelectiveModel()

    checkpoint_callback = ModelCheckpoint(
        dirpath=tmp_path,
        filename="best-{epoch}-{val_loss:.2f}",
        monitor="val_loss",
        save_last=True,
        save_top_k=1,
        mode="min",
        every_n_epochs=1,
        save_on_train_epoch_end=False,
    )

    trainer = Trainer(
        # Keep every artifact the trainer may write inside the temporary directory.
        default_root_dir=tmp_path,
        max_epochs=4,
        callbacks=[checkpoint_callback],
        logger=False,
        enable_progress_bar=False,
        limit_train_batches=2,
        limit_val_batches=2,
        enable_checkpointing=True,
    )

    trainer.fit(model)

    # Sorted so that failure messages are deterministic regardless of glob order.
    checkpoint_names = sorted(path.name for path in tmp_path.glob("*.ckpt"))
    assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved"

    # The literal "best-" prefix comes from the `filename` template passed to ModelCheckpoint.
    best_names = [name for name in checkpoint_names if name.startswith("best-")]
    assert len(best_names) == 1, f"Expected exactly one best checkpoint, got: {checkpoint_names}"

    expected_files = 2  # best checkpoint + last.ckpt
    assert len(checkpoint_names) == expected_files, (
        f"Expected {expected_files} files, got {len(checkpoint_names)}: {checkpoint_names}"
    )
0 commit comments