Skip to content

Commit 4658299

Browse files
Jeff YangBorda
authored andcommitted
[tests/checkpointing] refactor with BoringModel (#4661)
* [tests/checkpointing] refactor with BoringModel * [tests/checkpointing] refactor with BoringModel * [tests/checkpointing] refactor with BoringModel * LessBoringModel -> LogInTwoMethods * LessBoringModel -> LogInTwoMethods * LessBoringModel -> TrainingStepCalled Co-authored-by: chaton <[email protected]> Co-authored-by: Ananya Harsh Jha <[email protected]> (cherry picked from commit 7d96fd1)
1 parent 8ebd28f commit 4658299

File tree

3 files changed

+112
-57
lines changed

3 files changed

+112
-57
lines changed

tests/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
import torch
1919

2020
from pytorch_lightning import Trainer, callbacks, seed_everything
21-
from tests.base import BoringModel, EvalModelTemplate
21+
from tests.base import BoringModel
2222

2323

2424
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
2525
def test_mc_called_on_fastdevrun(tmpdir):
2626
seed_everything(1234)
2727

28-
train_val_step_model = EvalModelTemplate()
28+
train_val_step_model = BoringModel()
2929

3030
# fast dev run = called once
3131
# train loop only, dict, eval result
@@ -38,7 +38,18 @@ def test_mc_called_on_fastdevrun(tmpdir):
3838
# -----------------------
3939
# also called once with no val step
4040
# -----------------------
41-
train_step_only_model = EvalModelTemplate()
41+
class TrainingStepCalled(BoringModel):
42+
def __init__(self):
43+
super().__init__()
44+
self.training_step_called = False
45+
self.validation_step_called = False
46+
self.test_step_called = False
47+
48+
def training_step(self, batch, batch_idx):
49+
self.training_step_called = True
50+
return super().training_step(batch, batch_idx)
51+
52+
train_step_only_model = TrainingStepCalled()
4253
train_step_only_model.validation_step = None
4354

4455
# fast dev run = called once
@@ -62,7 +73,7 @@ def test_mc_called(tmpdir):
6273
# -----------------
6374
# TRAIN LOOP ONLY
6475
# -----------------
65-
train_step_only_model = EvalModelTemplate()
76+
train_step_only_model = BoringModel()
6677
train_step_only_model.validation_step = None
6778

6879
# no callback
@@ -73,7 +84,7 @@ def test_mc_called(tmpdir):
7384
# -----------------
7485
# TRAIN + VAL LOOP ONLY
7586
# -----------------
76-
val_train_model = EvalModelTemplate()
87+
val_train_model = BoringModel()
7788
# no callback
7889
trainer = Trainer(max_epochs=3, checkpoint_callback=False)
7990
trainer.fit(val_train_model)

0 commit comments

Comments
 (0)