Skip to content

Commit 02477d5

Browse files
committed
model checkpoint on exception: split trainer setup over two lines
1 parent 9f6063b commit 02477d5

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,8 @@ def training_step(self, batch, batch_idx):
773773

774774
model = TroubledModel()
775775
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
776-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False)
776+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback],
777+
max_epochs=5, logger=False, enable_progress_bar=False)
777778
with pytest.raises(RuntimeError, match="Trouble!"):
778779
trainer.fit(model)
779780
print(os.listdir(tmp_path))
@@ -785,15 +786,16 @@ class TroubledModel(BoringModel):
785786
def validation_step(self, batch, batch_idx):
786787
if not trainer.sanity_checking and batch_idx == 0:
787788
raise RuntimeError("Trouble!")
788-
789+
789790
model = TroubledModel()
790791
epoch_length = 64
791792
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
792-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback], max_epochs=5, logger=False, enable_progress_bar=False)
793+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback],
794+
max_epochs=5, logger=False, enable_progress_bar=False)
793795
with pytest.raises(RuntimeError, match="Trouble!"):
794796
trainer.fit(model)
795797
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
796-
798+
797799

798800
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
799801
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
@@ -804,7 +806,8 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
804806

805807
model = BoringModel()
806808
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
807-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], max_epochs=5, logger=False, enable_progress_bar=False)
809+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()],
810+
max_epochs=5, logger=False, enable_progress_bar=False)
808811
with pytest.raises(RuntimeError, match="Trouble!"):
809812
trainer.fit(model)
810813
assert os.path.isfile(tmp_path / "step=1.ckpt")
@@ -816,10 +819,11 @@ class TroublemakerOnTrainBatchEnd(Callback):
816819
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
817820
if batch_idx == 1:
818821
raise RuntimeError("Trouble!")
819-
822+
820823
model = BoringModel()
821824
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
822-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False)
825+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()],
826+
max_epochs=5, logger=False, enable_progress_bar=False)
823827
with pytest.raises(RuntimeError, match="Trouble!"):
824828
trainer.fit(model)
825829

@@ -832,11 +836,12 @@ class TroublemakerOnTrainEpochStart(Callback):
832836
def on_train_epoch_start(self, trainer, pl_module):
833837
if trainer.current_epoch == 1:
834838
raise RuntimeError("Trouble!")
835-
839+
836840
model = BoringModel()
837841
epoch_length = 64
838842
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
839-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False)
843+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
844+
max_epochs=5, logger=False, enable_progress_bar=False)
840845
with pytest.raises(RuntimeError, match="Trouble!"):
841846
trainer.fit(model)
842847
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
@@ -852,7 +857,8 @@ def on_train_epoch_end(self, trainer, pl_module):
852857
model = BoringModel()
853858
epoch_length = 64
854859
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
855-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False)
860+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
861+
max_epochs=5, logger=False, enable_progress_bar=False)
856862
with pytest.raises(RuntimeError, match="Trouble!"):
857863
trainer.fit(model)
858864
assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt")
@@ -868,7 +874,8 @@ def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
868874
model = BoringModel()
869875
epoch_length = 64
870876
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
871-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()], max_epochs=5, logger=False, enable_progress_bar=False)
877+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchStart()],
878+
max_epochs=5, logger=False, enable_progress_bar=False)
872879
with pytest.raises(RuntimeError, match="Trouble!"):
873880
trainer.fit(model)
874881
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
@@ -880,11 +887,12 @@ class TroublemakerOnValidationBatchEnd(Callback):
880887
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
881888
if not trainer.sanity_checking and batch_idx == 1:
882889
raise RuntimeError("Trouble!")
883-
890+
884891
model = BoringModel()
885892
epoch_length = 64
886893
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
887-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()], max_epochs=5, logger=False, enable_progress_bar=False)
894+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationBatchEnd()],
895+
max_epochs=5, logger=False, enable_progress_bar=False)
888896
with pytest.raises(RuntimeError, match="Trouble!"):
889897
trainer.fit(model)
890898
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
@@ -900,23 +908,25 @@ def on_validation_epoch_start(self, trainer, pl_module):
900908
model = BoringModel()
901909
epoch_length = 64
902910
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
903-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()], max_epochs=5, logger=False, enable_progress_bar=False)
911+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
912+
max_epochs=5, logger=False, enable_progress_bar=False)
904913
with pytest.raises(RuntimeError, match="Trouble!"):
905914
trainer.fit(model)
906915
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
907-
916+
908917

909918
def test_model_checkpoint_save_on_exception_in_val_callback_on_validation_epoch_end(tmp_path):
910919
"""Test that the checkpoint is saved when an exception is raised in a callback on validation_epoch_end."""
911920
class TroublemakerOnValidationEpochEnd(Callback):
912921
def on_validation_epoch_end(self, trainer, pl_module):
913922
if not trainer.sanity_checking and trainer.current_epoch == 0:
914923
raise RuntimeError("Trouble!")
915-
924+
916925
model = BoringModel()
917926
epoch_length = 64
918927
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
919-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()], max_epochs=5, logger=False, enable_progress_bar=False)
928+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
929+
max_epochs=5, logger=False, enable_progress_bar=False)
920930
with pytest.raises(RuntimeError, match="Trouble!"):
921931
trainer.fit(model)
922932
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
@@ -927,11 +937,12 @@ class TroublemakerOnValidationStart(Callback):
927937
def on_validation_start(self, trainer, pl_module):
928938
if not trainer.sanity_checking:
929939
raise RuntimeError("Trouble!")
930-
940+
931941
model = BoringModel()
932942
epoch_length = 64
933943
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
934-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()], max_epochs=5, logger=False, enable_progress_bar=False)
944+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
945+
max_epochs=5, logger=False, enable_progress_bar=False)
935946
with pytest.raises(RuntimeError, match="Trouble!"):
936947
trainer.fit(model)
937948
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
@@ -942,16 +953,17 @@ class TroublemakerOnValidationEnd(Callback):
942953
def on_validation_end(self, trainer, pl_module):
943954
if not trainer.sanity_checking:
944955
raise RuntimeError("Trouble!")
945-
956+
946957
model = BoringModel()
947958
epoch_length = 64
948959
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
949-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()], max_epochs=5, logger=False, enable_progress_bar=False)
960+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
961+
max_epochs=5, logger=False, enable_progress_bar=False)
950962
with pytest.raises(RuntimeError, match="Trouble!"):
951963
trainer.fit(model)
952964
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
953965

954-
966+
955967
@mock.patch("lightning.pytorch.callbacks.model_checkpoint.time")
956968
def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None:
957969
"""Tests that the checkpoints are saved at the specified time interval."""

0 commit comments

Comments
 (0)