@@ -40,6 +40,7 @@ def save_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     loss: float | None = None,
@@ -81,6 +82,8 @@ def save_checkpoint(
 
     # 2. Save global states
     global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0}
+    if outer_scheduler is not None:
+        global_state_dict["outer_scheduler"] = outer_scheduler.state_dict()
     if outer_optimizer is not None:
         global_state_dict["outer_optimizer"] = outer_optimizer.state_dict()
     if scaler is not None:
@@ -95,6 +98,7 @@ def load_checkpoint(
     model: torch.nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
+    outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
     outer_optimizer: torch.optim.Optimizer | None = None,
     scaler: torch.cuda.amp.GradScaler | None = None,
     data_loader: StatefulDataLoader | None = None,
@@ -139,8 +143,13 @@ def load_checkpoint(
     if scheduler is not None:
         scheduler.load_state_dict(global_state_dict["scheduler"])
         optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0]
+
     if outer_optimizer is not None:
         outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
+    if outer_scheduler is not None:
+        outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
+        outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0]
+
     if scaler is not None:
         scaler.load_state_dict(global_state_dict["scaler"])
     return global_state_dict["loss"]
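To make the new code path concrete, here is a minimal, self-contained sketch of what the added `outer_scheduler` handling amounts to. The parameter tensor, learning rate, and lambda below are illustrative assumptions, not values from the PR; only the `global_state_dict` keys and the `get_last_lr()` re-application mirror what the diff does inside `save_checkpoint`/`load_checkpoint`.

```python
# Illustrative sketch only: sizes, lr, and lr_lambda are made up; the PR
# itself just threads `outer_scheduler` through save/load of the checkpoint.
import torch

params = [torch.nn.Parameter(torch.zeros(4))]
outer_optimizer = torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)
outer_scheduler = torch.optim.lr_scheduler.LambdaLR(outer_optimizer, lr_lambda=lambda step: 0.9**step)

# Advance a few outer steps so the scheduled learning rate has decayed.
for _ in range(3):
    outer_optimizer.step()
    outer_scheduler.step()

# Save side: the same entries the diff adds to global_state_dict.
global_state_dict = {
    "outer_optimizer": outer_optimizer.state_dict(),
    "outer_scheduler": outer_scheduler.state_dict(),
}

# Load side: rebuild fresh objects, restore their states, and re-apply the
# last outer learning rate, mirroring the new branch in load_checkpoint.
restored_optimizer = torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)
restored_scheduler = torch.optim.lr_scheduler.LambdaLR(restored_optimizer, lr_lambda=lambda step: 0.9**step)
restored_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
restored_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
restored_optimizer.param_groups[0]["lr"] = restored_scheduler.get_last_lr()[0]

assert restored_scheduler.get_last_lr() == outer_scheduler.get_last_lr()
```

Without the explicit `param_groups[0]["lr"]` assignment, the freshly constructed outer optimizer would resume with its constructor learning rate until the next scheduler step; re-applying `get_last_lr()` keeps the restored outer learning rate consistent with where training left off, just as the existing inner `scheduler` branch already does.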