@@ -160,9 +160,11 @@ def load_checkpoint(
         raise FileNotFoundError(f"{path}.scaler not exist.")
 
     # load state dict
-    param_dict = paddle.load(f"{path}.pdparams")
+    model_dict = paddle.load(f"{path}.pdparams")
     optim_dict = paddle.load(f"{path}.pdopt")
-    metric_dict = paddle.load(f"{path}.pdstates")
+    metric_dict = {}
+    if os.path.exists(f"{path}.pdstates"):
+        metric_dict = paddle.load(f"{path}.pdstates")
     if grad_scaler is not None:
         scaler_dict = paddle.load(f"{path}.pdscaler")
     if equation is not None:
@@ -172,9 +174,9 @@ def load_checkpoint(
     else:
         equation_dict = paddle.load(f"{path}.pdeqn")
 
-    # set state dict
+    # set model state dict
     logger.message(f"* Loading model checkpoint from {path}.pdparams")
-    missing_keys, unexpected_keys = model.set_state_dict(param_dict)
+    missing_keys, unexpected_keys = model.set_state_dict(model_dict)
     if missing_keys:
         logger.warning(
             f"There are missing keys when loading checkpoint: {missing_keys}, "
@@ -186,20 +188,23 @@ def load_checkpoint(
             "and corresponding weights will be ignored."
         )
 
+    # set optimizer state dict
     logger.message(f"* Loading optimizer checkpoint from {path}.pdopt")
     optimizer.set_state_dict(optim_dict)
+
     if grad_scaler is not None:
         logger.message(f"* Loading grad scaler checkpoint from {path}.pdscaler")
         grad_scaler.load_state_dict(scaler_dict)
+
     if equation is not None and equation_dict is not None:
         logger.message(f"* Loading equation checkpoint from {path}.pdeqn")
         for name, _equation in equation.items():
             _equation.set_state_dict(equation_dict[name])
 
-    if ema_model:
+    if ema_model is not None:
         logger.message(f"* Loading EMA checkpoint from {path}_ema.pdparams")
-        avg_param_dict = paddle.load(f"{path}_ema.pdparams")
-        ema_model.set_state_dict(avg_param_dict)
+        avg_model_dict = paddle.load(f"{path}_ema.pdparams")
+        ema_model.set_state_dict(avg_model_dict)
 
     if aggregator is not None and aggregator.should_persist:
         logger.message(f"* Loading loss aggregator checkpoint from {path}.pdagg")
@@ -213,7 +218,7 @@ def load_checkpoint(
 def save_checkpoint(
     model: nn.Layer,
     optimizer: Optional[optimizer.Optimizer],
-    metric: Dict[str, float],
+    metric: Optional[Dict[str, float]] = None,
     grad_scaler: Optional[amp.GradScaler] = None,
     output_dir: Optional[str] = None,
     prefix: str = "model",
@@ -228,7 +233,7 @@ def save_checkpoint(
     Args:
         model (nn.Layer): Model with parameters.
         optimizer (Optional[optimizer.Optimizer]): Optimizer for model.
-        metric (Dict[str, float]): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}.
+        metric (Optional[Dict[str, float]]): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}. Defaults to None.
         grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
         output_dir (Optional[str]): Directory for checkpoint storage.
         prefix (str, optional): Prefix for storage. Defaults to "model".
@@ -259,11 +264,16 @@ def save_checkpoint(
     os.makedirs(ckpt_dir, exist_ok=True)
 
     paddle.save(model.state_dict(), f"{ckpt_path}.pdparams")
-    if optimizer:
+
+    if optimizer is not None:
         paddle.save(optimizer.state_dict(), f"{ckpt_path}.pdopt")
-    paddle.save(metric, f"{ckpt_path}.pdstates")
+
+    if metric is not None and len(metric) > 0:
+        paddle.save(metric, f"{ckpt_path}.pdstates")
+
     if grad_scaler is not None:
         paddle.save(grad_scaler.state_dict(), f"{ckpt_path}.pdscaler")
+
     if equation is not None:
         num_learnable_params = sum(
             [len(eq.learnable_parameters) for eq in equation.values()]
@@ -274,10 +284,10 @@ def save_checkpoint(
                 f"{ckpt_path}.pdeqn",
             )
 
-    if ema_model:
+    if ema_model is not None:
         paddle.save(ema_model.state_dict(), f"{ckpt_path}_ema.pdparams")
 
-    if aggregator and aggregator.should_persist:
+    if aggregator is not None and aggregator.should_persist:
         paddle.save(aggregator.state_dict(), f"{ckpt_path}.pdagg")
 
     if print_log:
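Taken together, these changes make `metric` optional on the save path (no ".pdstates" file is written when it is None or empty) and make the load path tolerant of such a checkpoint (`metric_dict` falls back to an empty dict instead of `paddle.load` raising on a missing file). A minimal usage sketch follows; the `ppsci.utils.save_load` import path, the toy model/optimizer, and the "./output/checkpoints/latest" path layout are assumptions for illustration, not taken from this diff:

import paddle
from paddle import nn

from ppsci.utils import save_load  # assumed module path

# Toy model and optimizer, for illustration only.
model = nn.Linear(4, 4)
opt = paddle.optimizer.Adam(parameters=model.parameters())

# With this change, `metric` can be omitted entirely; no ".pdstates"
# file is written when it is None or empty.
save_load.save_checkpoint(model, opt, output_dir="./output", prefix="latest")

# Loading now tolerates the absent ".pdstates" file: metric_dict is
# initialized to {} and only filled in if the file exists on disk.
save_load.load_checkpoint("./output/checkpoints/latest", model, opt)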