Commit 779048d

[Feature] Add loss aggregator to saving/loading process (#949)
* fix document and formulations
* correct doc of Data.batch_transform
* update schema for None batch size in INFER config
* support setting batch_size to None for data that does not need to be batchified
* update pre-commit
* refine EVAL config schema
* save & load state_dict for the loss aggregator
1 parent 273bfb0 commit 779048d

10 files changed: +92 -5 lines changed

docs/zh/install_setup.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -95,7 +95,7 @@
 
 ``` sh
 cd PaddleScience/
-set PYTHONPATH=%cd%
+set PYTHONPATH=%PYTHONPATH%;%CD%
 pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple # manually install requirements
 ```
 
````

ppsci/loss/mtl/agda.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from typing import ClassVar
 from typing import List
 
 import paddle
@@ -30,6 +31,10 @@ class AGDA(base.LossAggregator):
 
     NOTE: This loss aggregator is only suitable for two-task learning and the first task loss must be PDE loss.
 
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
     Args:
         model (nn.Layer): Training model.
         M (int, optional): Smoothing period. Defaults to 100.
@@ -49,6 +54,7 @@ class AGDA(base.LossAggregator):
         ...     bc_loss = paddle.sum((y2 - 2) ** 2)
         ...     loss_aggregator({'pde_loss': pde_loss, 'bc_loss': bc_loss}).backward()
     """
+    should_persist: ClassVar[bool] = False
 
     def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
         super().__init__(model)
```

ppsci/loss/mtl/base.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing import ClassVar
 from typing import Dict
 from typing import Union
 
@@ -27,10 +28,16 @@
 class LossAggregator(nn.Layer):
     """Base class of loss aggregator mainly for multitask learning.
 
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
     Args:
         model (nn.Layer): Training model.
     """
 
+    should_persist: ClassVar[bool] = False
+
     def __init__(self, model: nn.Layer) -> None:
         super().__init__()
         self.model = model
@@ -52,3 +59,10 @@ def backward(self) -> None:
         raise NotImplementedError(
             f"'backward' should be implemented in subclass {self.__class__.__name__}"
         )
+
+    def state_dict(self):
+        agg_state = super().state_dict()
+        model_state = self.model.state_dict()
+        # drop model parameters from the aggregator's state dict, since they are already saved in *.pdparams
+        agg_state = {k: v for k, v in agg_state.items() if k not in model_state}
+        return agg_state
```
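The `state_dict` override is what keeps the new `*.pdagg` files small: an aggregator serializes only its own parameters and buffers, since the wrapped model is already saved to `*.pdparams`. Below is a minimal sketch of a subclass that opts into persistence; the `ToyAggregator` name and its `loss_weight` parameter are illustrative, not part of this commit.

```python
from paddle import nn

from ppsci.loss.mtl import base


class ToyAggregator(base.LossAggregator):
    """Illustrative aggregator that owns one learnable loss weight."""

    should_persist = True  # it has its own parameter, so it should be saved and restored

    def __init__(self, model: nn.Layer) -> None:
        super().__init__(model)
        self.loss_weight = self.create_parameter(
            shape=[2], default_initializer=nn.initializer.Constant(1.0)
        )


agg = ToyAggregator(nn.Linear(4, 1))
# the overridden state_dict is meant to keep only the aggregator's own state
# (e.g. the loss weight) and drop entries duplicated from the wrapped model
print(list(agg.state_dict().keys()))
```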

ppsci/loss/mtl/grad_norm.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from typing import ClassVar
 from typing import Dict
 from typing import List
 
@@ -42,6 +43,10 @@ class GradNorm(base.LossAggregator):
         \end{align*}
     $$
 
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
     Args:
         model (nn.Layer): Training model.
         num_losses (int, optional): Number of losses. Defaults to 1.
@@ -63,6 +68,7 @@ class GradNorm(base.LossAggregator):
         ...     loss2 = paddle.sum((y2 - 2) ** 2)
         ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
     """
+    should_persist: ClassVar[bool] = True
    weight: paddle.Tensor

    def __init__(
```
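GradNorm carries state of its own (the adaptive per-loss `weight` tensor), which is why it sets `should_persist = True`, while stateless aggregators such as AGDA, PCGrad, and Sum keep `False`. A minimal sketch of how the flag is intended to be consumed, mirroring the `save_checkpoint` change later in this commit; the output path is illustrative.

```python
import paddle
from paddle import nn

from ppsci.loss import mtl

model = nn.Linear(4, 1)
aggregator = mtl.GradNorm(model, num_losses=2)

# only aggregators that own parameters/buffers get written next to the checkpoint
if aggregator.should_persist:
    paddle.save(aggregator.state_dict(), "epoch_100.pdagg")  # illustrative path
```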

ppsci/loss/mtl/ntk.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -14,15 +14,21 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+from typing import ClassVar
 from typing import List
 
 import paddle
-from paddle import nn
 
 from ppsci.loss.mtl import base
 
+if TYPE_CHECKING:
+    from paddle import nn
+
 
 class NTK(base.LossAggregator):
+    should_persist: ClassVar[bool] = True
+
     def __init__(
         self,
         model: nn.Layer,
```

ppsci/loss/mtl/pcgrad.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from typing import ClassVar
 from typing import List
 
 import numpy as np
@@ -31,6 +32,10 @@ class PCGrad(base.LossAggregator):
 
     Code reference: [https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py](https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py)
 
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
     Args:
         model (nn.Layer): Training model.
 
@@ -48,6 +53,7 @@ class PCGrad(base.LossAggregator):
         ...     loss2 = paddle.sum((y2 - 2) ** 2)
         ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
     """
+    should_persist: ClassVar[bool] = False
 
     def __init__(self, model: nn.Layer) -> None:
         super().__init__(model)
```

ppsci/loss/mtl/relobralo.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from typing import ClassVar
 from typing import Dict
 
 import paddle
@@ -26,6 +27,10 @@ class Relobralo(nn.Layer):
 
     [Multi-Objective Loss Balancing for Physics-Informed Deep Learning](https://arxiv.org/abs/2110.09813)
 
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
+
     Args:
         num_losses (int): Number of losses.
         alpha (float, optional): Ability for remembering past in paper. Defaults to 0.95.
@@ -49,6 +54,7 @@ class Relobralo(nn.Layer):
         ...     loss2 = paddle.sum((y2 - 2) ** 2)
         ...     loss_aggregator({'loss1': loss1, 'loss2': loss2}).backward()
     """
+    should_persist: ClassVar[bool] = True
 
     def __init__(
         self,
```

ppsci/loss/mtl/sum.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -20,6 +20,8 @@
 if TYPE_CHECKING:
     import paddle
 
+from typing import ClassVar
+
 from ppsci.loss.mtl.base import LossAggregator
 
 
@@ -30,7 +32,12 @@ class Sum(LossAggregator):
     $$
         loss = \sum_i^N losses_i
     $$
+
+    Attributes:
+        should_persist(bool): Whether to persist the loss aggregator when saving.
+            Those loss aggregators with parameters and/or buffers should be persisted.
     """
+    should_persist: ClassVar[bool] = False
 
     def __init__(self) -> None:
         self.step = 0
```

ppsci/solver/solver.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -334,6 +334,7 @@ def __init__(
                 self.scaler,
                 self.equation,
                 self.ema_model,
+                self.loss_aggregator,
             )
             if isinstance(loaded_metric, dict):
                 self.best_metric.update(loaded_metric)
@@ -567,6 +568,7 @@ def train(self) -> None:
                     self.output_dir,
                     "best_model",
                     self.equation,
+                    aggregator=self.loss_aggregator,
                 )
                 logger.info(
                     f"[Eval][Epoch {epoch_id}]"
@@ -633,6 +635,7 @@ def train(self) -> None:
                    f"epoch_{epoch_id}",
                    self.equation,
                    ema_model=self.ema_model,
+                    aggregator=self.loss_aggregator,
                )

            # save the latest model for convenient resume training
@@ -646,6 +649,7 @@ def train(self) -> None:
                self.equation,
                print_log=(epoch_id == start_epoch),
                ema_model=self.ema_model,
+                aggregator=self.loss_aggregator,
            )

    def finetune(self, pretrained_model_path: str) -> None:
```
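With the aggregator threaded through these call sites, a `Solver` built with a loss aggregator saves and restores that aggregator's state alongside the model. A hedged sketch, assuming the `Solver` constructor exposes a `loss_aggregator` argument that populates the `self.loss_aggregator` used above; constraints, optimizer, and the rest of the training setup are omitted.

```python
import ppsci
from ppsci.loss import mtl

model = ppsci.arch.MLP(("x", "y"), ("u",), 3, 16)
aggregator = mtl.GradNorm(model, num_losses=2)

# assumption: Solver accepts a `loss_aggregator` argument that ends up in
# self.loss_aggregator; training setup (constraints, optimizer, ...) omitted
solver = ppsci.solver.Solver(
    model,
    output_dir="./output",
    loss_aggregator=aggregator,
)
# during training, checkpoints written by this solver also include an
# `epoch_*.pdagg` file whenever aggregator.should_persist is True
```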

ppsci/utils/save_load.py

Lines changed: 35 additions & 3 deletions
```diff
@@ -31,6 +31,7 @@
 from paddle import optimizer
 
 from ppsci import equation
+from ppsci.loss import mtl
 from ppsci.utils import ema
 
 
@@ -42,7 +43,10 @@
 
 
 def _load_pretrain_from_path(
-    path: str, model: nn.Layer, equation: Optional[Dict[str, equation.PDE]] = None
+    path: str,
+    model: nn.Layer,
+    equation: Optional[Dict[str, equation.PDE]] = None,
+    loss_aggregator: Optional[mtl.LossAggregator] = None,
 ):
     """Load pretrained model from given path.
 
@@ -77,9 +81,26 @@ def _load_pretrain_from_path(
            f"Finish loading pretrained equation parameters from: {path}.pdeqn"
        )
 
+    if loss_aggregator is not None:
+        if not os.path.exists(f"{path}.pdagg"):
+            if loss_aggregator.should_persist:
+                logger.warning(
+                    f"Given loss_aggregator({type(loss_aggregator)}) has persistable "
+                    f"parameters or buffers, but {path}.pdagg not found."
+                )
+        else:
+            aggregator_dict = paddle.load(f"{path}.pdagg")
+            loss_aggregator.set_state_dict(aggregator_dict)
+            logger.message(
+                f"Finish loading pretrained loss aggregator parameters from: {path}.pdagg"
+            )
+
 
 def load_pretrain(
-    model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None
+    model: nn.Layer,
+    path: str,
+    equation: Optional[Dict[str, equation.PDE]] = None,
+    loss_aggregator: Optional[mtl.LossAggregator] = None,
 ):
     """
     Load pretrained model from given path or url.
@@ -121,7 +142,7 @@ def is_url_accessible(url: str):
     # remove ".pdparams" in suffix of path for convenient
     if path.endswith(".pdparams"):
         path = path[:-9]
-    _load_pretrain_from_path(path, model, equation)
+    _load_pretrain_from_path(path, model, equation, loss_aggregator)
 
 
 def load_checkpoint(
@@ -131,6 +152,7 @@ def load_checkpoint(
     grad_scaler: Optional[amp.GradScaler] = None,
     equation: Optional[Dict[str, equation.PDE]] = None,
     ema_model: Optional[ema.AveragedModel] = None,
+    aggregator: Optional[mtl.LossAggregator] = None,
 ) -> Dict[str, Any]:
     """Load from checkpoint.
 
@@ -141,6 +163,7 @@ def load_checkpoint(
        grad_scaler (Optional[amp.GradScaler]): GradScaler for AMP. Defaults to None.
        equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.
        ema_model: Optional[ema.AveragedModel]: Average model. Defaults to None.
+        aggregator: Optional[mtl.LossAggregator]: Loss aggregator. Defaults to None.
 
     Returns:
         Dict[str, Any]: Loaded metric information.
@@ -189,6 +212,10 @@ def load_checkpoint(
        avg_param_dict = paddle.load(f"{path}_ema.pdparams")
        ema_model.set_state_dict(avg_param_dict)
 
+    if aggregator is not None:
+        aggregator_dict = paddle.load(f"{path}.pdagg")
+        aggregator.set_state_dict(aggregator_dict)
+
     logger.message(f"Finish loading checkpoint from {path}")
     return metric_dict
 
@@ -203,6 +230,7 @@ def save_checkpoint(
     equation: Optional[Dict[str, equation.PDE]] = None,
     print_log: bool = True,
     ema_model: Optional[ema.AveragedModel] = None,
+    aggregator: Optional[mtl.LossAggregator] = None,
 ):
     """
     Save checkpoint, including model params, optimizer params, metric information.
@@ -219,6 +247,7 @@ def save_checkpoint(
            keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
            Defaults to True.
        ema_model: Optional[ema.AveragedModel]: Average model. Defaults to None.
+        aggregator: Optional[mtl.LossAggregator]: Loss aggregator. Defaults to None.
 
     Examples:
         >>> import ppsci
@@ -258,6 +287,9 @@ def save_checkpoint(
    if ema_model:
        paddle.save(ema_model.state_dict(), f"{ckpt_path}_ema.pdparams")
 
+    if aggregator and aggregator.should_persist:
+        paddle.save(aggregator.state_dict(), f"{ckpt_path}.pdagg")
+
    if print_log:
        log_str = f"Finish saving checkpoint to: {ckpt_path}"
        if prefix == "latest":
```
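Putting the pieces together, a persistable aggregator now produces a third artifact, `<prefix>.pdagg`, next to `<prefix>.pdparams` and `<prefix>.pdopt`, and `load_pretrain`/`load_checkpoint` restore it. A minimal round-trip sketch using the `load_pretrain` signature shown above; the `./agg_demo` prefix is illustrative.

```python
import paddle
from paddle import nn

from ppsci.loss import mtl
from ppsci.utils import save_load

model = nn.Linear(4, 1)
aggregator = mtl.GradNorm(model, num_losses=2)

prefix = "./agg_demo"  # illustrative checkpoint prefix
# what save_checkpoint now does on top of saving the model weights
paddle.save(model.state_dict(), f"{prefix}.pdparams")
if aggregator.should_persist:
    paddle.save(aggregator.state_dict(), f"{prefix}.pdagg")

# restores the model weights and, since `{prefix}.pdagg` exists, the aggregator state
save_load.load_pretrain(model, prefix, loss_aggregator=aggregator)
```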
