Skip to content

Commit 44ebae5

Browse files
committed
update
1 parent cdb51e2 commit 44ebae5

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

tests/tests_pytorch/callbacks/test_callback_hooks.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,42 +58,53 @@ def on_test_batch_end(self, outputs, *_):
5858
trainer.fit(model)
5959

6060

61-
def test_callback_on_before_optimizer_setup(tmp_path):
62-
"""Tests that on_before_optimizer_step is called as expected."""
61+
def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
62+
"""Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""
6363

64-
class CB(Callback):
64+
order = []
65+
66+
class TestCallback(Callback):
6567
def setup(self, trainer, pl_module, stage=None):
68+
order.append("setup")
69+
assert pl_module.layer is None
6670
assert len(trainer.optimizers) == 0
67-
assert pl_module.layer is None # called before `LightningModule.configure_model`
6871

6972
def on_before_optimizer_setup(self, trainer, pl_module):
70-
assert len(trainer.optimizers) == 0 # `LightningModule.configure_optimizers` hasn't been called yet
71-
assert pl_module.layer is not None # called after `LightningModule.configure_model`
73+
order.append("on_before_optimizer_setup")
74+
# configure_model should already have been called
75+
assert pl_module.layer is not None
76+
# but optimizers are not yet created
77+
assert len(trainer.optimizers) == 0
7278

7379
def on_fit_start(self, trainer, pl_module):
80+
order.append("on_fit_start")
81+
# optimizers should now exist
7482
assert len(trainer.optimizers) == 1
75-
assert pl_module.layer is not None # called after `LightningModule.configure_model`
83+
assert pl_module.layer is not None
7684

7785
class DemoModel(BoringModel):
7886
def __init__(self):
7987
super().__init__()
80-
self.layer = None # initialize layer in `configure_model`
88+
self.layer = None
8189

8290
def configure_model(self):
83-
import torch.nn as nn
91+
from torch import nn
8492

8593
self.layer = nn.Linear(32, 2)
8694

8795
model = DemoModel()
8896

8997
trainer = Trainer(
90-
callbacks=CB(),
98+
callbacks=TestCallback(),
9199
default_root_dir=tmp_path,
92100
limit_train_batches=2,
93101
limit_val_batches=2,
94102
max_epochs=1,
95-
log_every_n_steps=1,
96103
enable_model_summary=False,
104+
log_every_n_steps=1,
97105
)
98106

99107
trainer.fit(model)
108+
109+
# Verify call order
110+
assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"]

0 commit comments

Comments
 (0)