Skip to content

Commit 2113acc

Browse files
committed
split test for save checksave point on expection for expetions in training part of callbacks in individal test for better overview
1 parent e0dae53 commit 2113acc

File tree

1 file changed

+94
-20
lines changed

1 file changed

+94
-20
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -764,52 +764,126 @@ def test_ckpt_every_n_train_steps(tmp_path):
764764
assert set(os.listdir(tmp_path)) == set(expected)
765765

766766

767-
def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
768-
"""Test that the checkpoint is saved when an exception is raised in a callback on different events."""
767+
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_start(tmp_path):
768+
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_start."""
769769
class TroublemakerOnTrainBatchStart(Callback):
770770
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
771771
if batch_idx == 1:
772772
raise RuntimeError("Trouble!")
773773

774+
model = BoringModel()
775+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
776+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchStart()], max_epochs=5, logger=False)
777+
with pytest.raises(RuntimeError, match="Trouble!"):
778+
trainer.fit(model)
779+
assert os.path.isfile(tmp_path / "step=1.ckpt")
780+
781+
782+
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_batch_end(tmp_path):
783+
"""Test that the checkpoint is saved when an exception is raised in a callback on train_batch_end."""
774784
class TroublemakerOnTrainBatchEnd(Callback):
775785
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
776786
if batch_idx == 1:
777787
raise RuntimeError("Trouble!")
788+
789+
model = BoringModel()
790+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
791+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainBatchEnd()], max_epochs=5, logger=False)
792+
with pytest.raises(RuntimeError, match="Trouble!"):
793+
trainer.fit(model)
778794

795+
assert os.path.isfile(tmp_path / "step=2.ckpt")
796+
797+
798+
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_start(tmp_path):
799+
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_start."""
779800
class TroublemakerOnTrainEpochStart(Callback):
780801
def on_train_epoch_start(self, trainer, pl_module):
781802
if trainer.current_epoch == 1:
782803
raise RuntimeError("Trouble!")
804+
805+
model = BoringModel()
806+
epoch_length = 64
807+
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
808+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()], max_epochs=5, logger=False)
809+
with pytest.raises(RuntimeError, match="Trouble!"):
810+
trainer.fit(model)
811+
assert os.path.isfile(tmp_path / f"step={epoch_length}.ckpt")
812+
783813

814+
def test_model_checkpoint_save_on_exception_in_train_callback_on_train_epoch_end(tmp_path):
815+
"""Test that the checkpoint is saved when an exception is raised in a callback on train_epoch_end."""
784816
class TroublemakerOnTrainEpochEnd(Callback):
785817
def on_train_epoch_end(self, trainer, pl_module):
786818
if trainer.current_epoch == 1:
787819
raise RuntimeError("Trouble!")
788820

789-
790-
epoch_length = 64
791821
model = BoringModel()
792-
# use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
822+
epoch_length = 64
793823
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
824+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()], max_epochs=5, logger=False)
825+
with pytest.raises(RuntimeError, match="Trouble!"):
826+
trainer.fit(model)
827+
assert os.path.isfile(tmp_path / f"step={2*epoch_length}.ckpt")
828+
829+
830+
# def test_model_checkpoint_save_on_exception_in_train_callback(tmp_path):
831+
# """Test that the checkpoint is saved when an exception is raised in a callback on different events."""
832+
# class TroublemakerOnTrainBatchStart(Callback):
833+
# def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
834+
# if batch_idx == 1:
835+
# raise RuntimeError("Trouble!")
836+
837+
# class TroublemakerOnTrainBatchEnd(Callback):
838+
# def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
839+
# if batch_idx == 1:
840+
# raise RuntimeError("Trouble!")
841+
842+
# class TroublemakerOnTrainEpochStart(Callback):
843+
# def on_train_epoch_start(self, trainer, pl_module):
844+
# if trainer.current_epoch == 1:
845+
# raise RuntimeError("Trouble!")
846+
847+
# class TroublemakerOnTrainEpochEnd(Callback):
848+
# def on_train_epoch_end(self, trainer, pl_module):
849+
# if trainer.current_epoch == 1:
850+
# raise RuntimeError("Trouble!")
851+
852+
853+
# epoch_length = 64
854+
# model = BoringModel()
855+
# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
856+
# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
857+
858+
# troublemakers = [
859+
# TroublemakerOnTrainBatchStart(),
860+
# TroublemakerOnTrainBatchEnd(),
861+
# TroublemakerOnTrainEpochStart(),
862+
# TroublemakerOnTrainEpochEnd()
863+
# ]
864+
865+
# expected_ckpts = ["step=1.ckpt",
866+
# 'step=2.ckpt',
867+
# f'step={epoch_length}.ckpt',
868+
# f'step={2*epoch_length}.ckpt',
869+
# ]
870+
871+
# for troublemaker in troublemakers:
872+
# trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
873+
874+
# with pytest.raises(RuntimeError, match="Trouble!"):
875+
# trainer.fit(model)
794876

795-
troublemakers = [
796-
TroublemakerOnTrainBatchStart(),
797-
TroublemakerOnTrainBatchEnd(),
798-
TroublemakerOnTrainEpochStart(),
799-
TroublemakerOnTrainEpochEnd()
800-
]
877+
# assert set(os.listdir(tmp_path)) == set(expected_ckpts)
801878

802-
expected_ckpts = ["step=1.ckpt",
803-
'step=2.ckpt',
804-
f'step={epoch_length}.ckpt',
805-
f'step={2*epoch_length}.ckpt',
806-
]
807879

808-
for troublemaker in troublemakers:
809-
trainer = Trainer(default_root_dir=tmp_path, callbacks=[checkpoint_callback, troublemaker], max_epochs=5, logger=False)
880+
# # use every_n_epochs so that we can differentiate between the normal and the troublemaker checkpoints
881+
# checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
810882

811-
with pytest.raises(RuntimeError, match="Trouble!"):
812-
trainer.fit(model)
883+
# troublemakers = [
884+
# # TroublemakerOnValidationBatchStart(),
885+
# TroublemakerOnValidationBatchEnd(),
886+
# expected_ckpts = [f"step={2*epoch_length}.ckpt",
813887

814888
assert set(os.listdir(tmp_path)) == set(expected_ckpts)
815889

0 commit comments

Comments
 (0)