@@ -58,42 +58,53 @@ def on_test_batch_end(self, outputs, *_):
     trainer.fit(model)
 
 
-def test_callback_on_before_optimizer_setup(tmp_path):
-    """Tests that on_before_optimizer_step is called as expected."""
+def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
+    """Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""
 
-    class CB(Callback):
+    order = []
+
+    class TestCallback(Callback):
         def setup(self, trainer, pl_module, stage=None):
+            order.append("setup")
+            assert pl_module.layer is None
             assert len(trainer.optimizers) == 0
-            assert pl_module.layer is None  # called before `LightningModule.configure_model`
 
         def on_before_optimizer_setup(self, trainer, pl_module):
-            assert len(trainer.optimizers) == 0  # `LightningModule.configure_optimizers` hasn't been called yet
-            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+            order.append("on_before_optimizer_setup")
+            # configure_model should already have been called
+            assert pl_module.layer is not None
+            # but optimizers are not yet created
+            assert len(trainer.optimizers) == 0
 
         def on_fit_start(self, trainer, pl_module):
+            order.append("on_fit_start")
+            # optimizers should now exist
             assert len(trainer.optimizers) == 1
-            assert pl_module.layer is not None  # called after `LightningModule.configure_model`
+            assert pl_module.layer is not None
 
     class DemoModel(BoringModel):
         def __init__(self):
             super().__init__()
-            self.layer = None  # initialize layer in `configure_model`
+            self.layer = None
 
         def configure_model(self):
-            import torch.nn as nn
+            from torch import nn
 
             self.layer = nn.Linear(32, 2)
 
     model = DemoModel()
 
     trainer = Trainer(
-        callbacks=CB(),
+        callbacks=TestCallback(),
         default_root_dir=tmp_path,
         limit_train_batches=2,
         limit_val_batches=2,
         max_epochs=1,
-        log_every_n_steps=1,
         enable_model_summary=False,
+        log_every_n_steps=1,
    )
 
     trainer.fit(model)
+
+    # Verify call order
+    assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"]
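
The window the test pins down, after `configure_model` but before `configure_optimizers`, is exactly where parameter surgery such as freezing belongs: in `setup` the layers may not exist yet, and by `on_fit_start` the optimizer already holds the parameters. A minimal sketch of that use case follows; the `backbone` attribute is a hypothetical name for a submodule created in `configure_model`, not something from the test above:

```python
from lightning.pytorch import Callback


class FreezeBackbone(Callback):
    """Sketch: freeze a submodule after it is materialized but before the optimizer sees it."""

    def on_before_optimizer_setup(self, trainer, pl_module):
        # `configure_model` has already run, so the layers exist here.
        # `backbone` is a hypothetical attribute created in `configure_model`.
        for param in pl_module.backbone.parameters():
            param.requires_grad = False
        # `configure_optimizers` has not run yet, so an optimizer built from
        # only the still-trainable parameters never receives the frozen weights.
```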