Skip to content

Commit f8be42e

Browse files
refine save_load (#1158)
1 parent 8e1f922 commit f8be42e

File tree

1 file changed: +23 additions, −13 deletions

ppsci/utils/save_load.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,11 @@ def load_checkpoint(
160160
raise FileNotFoundError(f"{path}.scaler not exist.")
161161

162162
# load state dict
163-
param_dict = paddle.load(f"{path}.pdparams")
163+
model_dict = paddle.load(f"{path}.pdparams")
164164
optim_dict = paddle.load(f"{path}.pdopt")
165-
metric_dict = paddle.load(f"{path}.pdstates")
165+
metric_dict = {}
166+
if os.path.exists(f"{path}.pdstates"):
167+
metric_dict = paddle.load(f"{path}.pdstates")
166168
if grad_scaler is not None:
167169
scaler_dict = paddle.load(f"{path}.pdscaler")
168170
if equation is not None:
@@ -172,9 +174,9 @@ def load_checkpoint(
172174
else:
173175
equation_dict = paddle.load(f"{path}.pdeqn")
174176

175-
# set state dict
177+
# set model state dict
176178
logger.message(f"* Loading model checkpoint from {path}.pdparams")
177-
missing_keys, unexpected_keys = model.set_state_dict(param_dict)
179+
missing_keys, unexpected_keys = model.set_state_dict(model_dict)
178180
if missing_keys:
179181
logger.warning(
180182
f"There are missing keys when loading checkpoint: {missing_keys}, "
@@ -186,20 +188,23 @@ def load_checkpoint(
186188
"and corresponding weights will be ignored."
187189
)
188190

191+
# set optimizer state dict
189192
logger.message(f"* Loading optimizer checkpoint from {path}.pdopt")
190193
optimizer.set_state_dict(optim_dict)
194+
191195
if grad_scaler is not None:
192196
logger.message(f"* Loading grad scaler checkpoint from {path}.pdscaler")
193197
grad_scaler.load_state_dict(scaler_dict)
198+
194199
if equation is not None and equation_dict is not None:
195200
logger.message(f"* Loading equation checkpoint from {path}.pdeqn")
196201
for name, _equation in equation.items():
197202
_equation.set_state_dict(equation_dict[name])
198203

199-
if ema_model:
204+
if ema_model is not None:
200205
logger.message(f"* Loading EMA checkpoint from {path}_ema.pdparams")
201-
avg_param_dict = paddle.load(f"{path}_ema.pdparams")
202-
ema_model.set_state_dict(avg_param_dict)
206+
avg_model_dict = paddle.load(f"{path}_ema.pdparams")
207+
ema_model.set_state_dict(avg_model_dict)
203208

204209
if aggregator is not None and aggregator.should_persist:
205210
logger.message(f"* Loading loss aggregator checkpoint from {path}.pdagg")
@@ -213,7 +218,7 @@ def load_checkpoint(
213218
def save_checkpoint(
214219
model: nn.Layer,
215220
optimizer: Optional[optimizer.Optimizer],
216-
metric: Dict[str, float],
221+
metric: Optional[Dict[str, float]] = None,
217222
grad_scaler: Optional[amp.GradScaler] = None,
218223
output_dir: Optional[str] = None,
219224
prefix: str = "model",
@@ -228,7 +233,7 @@ def save_checkpoint(
228233
Args:
229234
model (nn.Layer): Model with parameters.
230235
optimizer (Optional[optimizer.Optimizer]): Optimizer for model.
231-
metric (Dict[str, float]): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}.
236+
metric (Optional[Dict[str, float]]): Metric information, such as {"RMSE": 0.1, "MAE": 0.2}. Defaults to None.
232237
grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
233238
output_dir (Optional[str]): Directory for checkpoint storage.
234239
prefix (str, optional): Prefix for storage. Defaults to "model".
@@ -259,11 +264,16 @@ def save_checkpoint(
259264
os.makedirs(ckpt_dir, exist_ok=True)
260265

261266
paddle.save(model.state_dict(), f"{ckpt_path}.pdparams")
262-
if optimizer:
267+
268+
if optimizer is not None:
263269
paddle.save(optimizer.state_dict(), f"{ckpt_path}.pdopt")
264-
paddle.save(metric, f"{ckpt_path}.pdstates")
270+
271+
if metric is not None and len(metric) > 0:
272+
paddle.save(metric, f"{ckpt_path}.pdstates")
273+
265274
if grad_scaler is not None:
266275
paddle.save(grad_scaler.state_dict(), f"{ckpt_path}.pdscaler")
276+
267277
if equation is not None:
268278
num_learnable_params = sum(
269279
[len(eq.learnable_parameters) for eq in equation.values()]
@@ -274,10 +284,10 @@ def save_checkpoint(
274284
f"{ckpt_path}.pdeqn",
275285
)
276286

277-
if ema_model:
287+
if ema_model is not None:
278288
paddle.save(ema_model.state_dict(), f"{ckpt_path}_ema.pdparams")
279289

280-
if aggregator and aggregator.should_persist:
290+
if aggregator is not None and aggregator.should_persist:
281291
paddle.save(aggregator.state_dict(), f"{ckpt_path}.pdagg")
282292

283293
if print_log:

Comments (0)