@@ -58,42 +58,53 @@ def on_test_batch_end(self, outputs, *_):
58
58
trainer .fit (model )
59
59
60
60
61
def test_on_before_optimizer_setup_is_called_in_correct_order(tmp_path):
    """Ensure `on_before_optimizer_setup` runs after `configure_model` but before `configure_optimizers`."""
    # Records the hook invocation sequence so the final assertion can check ordering.
    order = []

    class TestCallback(Callback):
        def setup(self, trainer, pl_module, stage=None):
            order.append("setup")
            # `configure_model` has not run yet, so the lazily-built layer is absent
            assert pl_module.layer is None
            assert len(trainer.optimizers) == 0

        def on_before_optimizer_setup(self, trainer, pl_module):
            order.append("on_before_optimizer_setup")
            # configure_model should already have been called
            assert pl_module.layer is not None
            # but optimizers are not yet created
            assert len(trainer.optimizers) == 0

        def on_fit_start(self, trainer, pl_module):
            order.append("on_fit_start")
            # optimizers should now exist
            assert len(trainer.optimizers) == 1
            assert pl_module.layer is not None

    class DemoModel(BoringModel):
        def __init__(self):
            super().__init__()
            # initialized lazily in `configure_model` so the hooks above can observe the transition
            self.layer = None

        def configure_model(self):
            from torch import nn

            self.layer = nn.Linear(32, 2)

    model = DemoModel()

    trainer = Trainer(
        callbacks=TestCallback(),
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        enable_model_summary=False,
        log_every_n_steps=1,
    )

    trainer.fit(model)

    # Verify call order
    assert order == ["setup", "on_before_optimizer_setup", "on_fit_start"]
0 commit comments