
Commit 755791f

update code
1 parent e1cdba5 commit 755791f

12 files changed: +448 additions, -17 deletions

docs/zh/api/loss/mtl.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     options:
       members:
         - AGDA
+        - GradNorm
         - LossAggregator
         - PCGrad
         - Relobralo
Lines changed: 303 additions & 0 deletions (new file)
@@ -0,0 +1,303 @@
"""
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
"""

from os import path as osp

import hydra
import numpy as np
import paddle
import scipy.io as sio
from matplotlib import pyplot as plt
from omegaconf import DictConfig

import ppsci
from ppsci.loss import mtl
from ppsci.utils import misc

dtype = paddle.get_default_dtype()


def plot(
    t_star: np.ndarray,
    x_star: np.ndarray,
    u_ref: np.ndarray,
    u_pred: np.ndarray,
    output_dir: str,
):
    fig = plt.figure(figsize=(18, 5))
    TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
    u_ref = u_ref.reshape([len(t_star), len(x_star)])

    plt.subplot(1, 3, 1)
    plt.pcolor(TT, XX, u_ref, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Exact")
    plt.tight_layout()

    plt.subplot(1, 3, 2)
    plt.pcolor(TT, XX, u_pred, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Predicted")
    plt.tight_layout()

    plt.subplot(1, 3, 3)
    plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Absolute error")
    plt.tight_layout()

    fig_path = osp.join(output_dir, "ac.png")
    print(f"Saving figure to {fig_path}")
    fig.savefig(fig_path, bbox_inches="tight", dpi=400)
    plt.close()


def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # set equation
    equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}

    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]

    u0 = u_ref[0, :]  # [nx, ]

    t0 = t_star[0]  # float
    t1 = t_star[-1]  # float

    x0 = x_star[0]  # float
    x1 = x_star[-1]  # float

    # set constraint
    def gen_input_batch():
        tx = np.random.uniform(
            [t0, x0],
            [t1, x1],
            (cfg.TRAIN.batch_size, 2),
        ).astype(dtype)
        return {
            "t": np.sort(tx[:, 0:1], axis=0),
            "x": tx[:, 1:2],
        }

    def gen_label_batch(input_batch):
        return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}

    pde_constraint = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "ContinuousNamedArrayDataset",
                "input": gen_input_batch,
                "label": gen_label_batch,
            },
        },
        output_expr=equation["AllenCahn"].equations,
        loss=ppsci.loss.CausalMSELoss(
            cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
        ),
        name="PDE",
    )
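
    # note: CausalMSELoss follows the causal training scheme of Wang et al.:
    # the batch is split into n_chunks along the (sorted) time axis, which is
    # why gen_input_batch sorts "t", and chunk i is weighted by
    # w_i = exp(-tol * sum_{j<i} L_j), so later times only start to contribute
    # once earlier residuals are small. for example, chunk losses
    # [0.01, 5.0, 3.0] with tol=1.0 give weights exp(-[0, 0.01, 5.01]) ~=
    # [1.0, 0.99, 0.0067]: the last chunk is effectively switched off.
    # (weighting rule taken from the causal-training literature, not this diff)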

    ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
    ic_label = {"u": u0.reshape([-1, 1])}
    ic = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "IterableNamedArrayDataset",
                "input": ic_input,
                "label": ic_label,
            },
        },
        output_expr={"u": lambda out: out["u"]},
        loss=ppsci.loss.MSELoss("mean"),
        name="IC",
    )
    # wrap constraints together
    constraint = {
        pde_constraint.name: pde_constraint,
        ic.name: ic,
    }

    # set optimizer
    lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        **cfg.TRAIN.lr_scheduler
    )()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # set validator
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    eval_label = {"u": u_ref.reshape([-1, 1])}
    u_validator = ppsci.validate.SupervisedValidator(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": eval_data,
                "label": eval_label,
            },
            "batch_size": cfg.EVAL.batch_size,
        },
        ppsci.loss.MSELoss("mean"),
        {"u": lambda out: out["u"]},
        metric={"L2Rel": ppsci.metric.L2Rel()},
        name="u_validator",
    )
    validator = {u_validator.name: u_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch,
        save_freq=cfg.TRAIN.save_freq,
        log_freq=cfg.log_freq,
        eval_during_train=True,
        eval_freq=cfg.TRAIN.eval_freq,
        equation=equation,
        validator=validator,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
        checkpoint_path=cfg.TRAIN.checkpoint_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
        use_tbd=True,
        loss_aggregator=mtl.GradNorm(model, len(constraint), 1000, 0.9),
        cfg=cfg,
    )
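
    # note: GradNorm replaces a fixed-weight sum of the PDE and IC losses with
    # weights driven by gradient magnitudes, so neither term dominates training.
    # the positional args (model, len(constraint), 1000, 0.9) presumably mean:
    # model whose gradients are measured, number of losses, refresh period in
    # steps, and a moving-average factor; inferred from the call, not this diff.
    # roughly, w_i is proportional to mean_j ||g_j|| / ||g_i||: grad norms
    # [10.0, 0.1] give weights [0.505, 50.5], equalizing both contributions
    # at 5.05 each.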
    # train model
    solver.train()
    # evaluate after training finishes
    solver.eval()
    # visualize the prediction
    u_pred = solver.predict(
        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
    )["u"]
    u_pred = u_pred.reshape([len(t_star), len(x_star)])

    # plot
    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def evaluate(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]

    # set validator
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    eval_label = {"u": u_ref.reshape([-1, 1])}
    u_validator = ppsci.validate.SupervisedValidator(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": eval_data,
                "label": eval_label,
            },
            "batch_size": cfg.EVAL.batch_size,
        },
        ppsci.loss.MSELoss("mean"),
        {"u": lambda out: out["u"]},
        metric={"L2Rel": ppsci.metric.L2Rel()},
        name="u_validator",
    )
    validator = {u_validator.name: u_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # evaluate
    solver.eval()
    # visualize the prediction
    u_pred = solver.predict(
        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
    )["u"]
    u_pred = u_pred.reshape([len(t_star), len(x_star)])

    # plot
    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
    ]
    solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)
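    # note: the [None, 1] spec leaves the batch dimension dynamic, so the
    # exported inference model accepts any number of (t, x) query points.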


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    predictor = pinn_predictor.PINNPredictor(cfg)
    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)

    input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
    # map raw predictor outputs back to the model's named output keys
    output_dict = {
        store_key: output_dict[infer_key]
        for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
    }
    u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])

    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


@hydra.main(
    version_base=None, config_path="./conf", config_name="allen_cahn_default.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
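
Because the script is driven by Hydra (reading allen_cahn_default.yaml from ./conf), the run mode is selected with a command-line override such as mode=train, mode=eval, mode=export, or mode=infer, and any other config field can be overridden the same way.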

examples/allen_cahn/conf/allen_cahn.yaml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ MODEL:
   hidden_size: 256
   activation: tanh
   periods:
-    t: [2.0, False]
+    x: [2.0, False]
 
 # training settings
 TRAIN:
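
This corrects the axis of the periodic embedding: the Allen-Cahn benchmark is periodic in space (x in [-1, 1], period 2.0), not in time, and the False flag keeps the period non-trainable. The same one-line fix lands in the causal-Fourier-RWF config below.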

examples/allen_cahn/conf/allen_cahn_causal_fourier_rwf.yaml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ MODEL:
   hidden_size: 256
   activation: tanh
   periods:
-    t: [2.0, False]
+    x: [2.0, False]
   fourier:
     dim: 256
     scale: 1.0

ppsci/arch/base.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def num_params(self) -> int:
         num = 0
         for name, param in self.named_parameters():
             if hasattr(param, "shape"):
-                num += np.prod(list(param.shape))
+                num += np.prod(list(param.shape), dtype="int")
             else:
                 logger.warning(f"{name} has no attribute 'shape'")
         return num
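
Without an explicit dtype, np.prod can hand back a float: np.prod([]) is 1.0, so any parameter with an empty shape (a scalar, such as the period parameters created with shape [] in mlp.py below) would silently turn the num_params total into a float. Pinning dtype="int" keeps the count integral:

import numpy as np

print(np.prod([]))                       # 1.0 -> a float would leak into num
print(np.prod([], dtype="int"))          # 1
print(np.prod([256, 256], dtype="int"))  # 65536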

ppsci/arch/mlp.py

Lines changed: 2 additions & 8 deletions
@@ -14,7 +14,6 @@
 
 from __future__ import annotations
 
-import math
 from typing import Dict
 from typing import Optional
 from typing import Tuple

@@ -77,12 +76,7 @@ def __init__(
 
     def _init_weights(self, mean, std):
         with paddle.no_grad():
-            # glorot normal
-            fin, fout = self.weight_v.shape
-            var = 2.0 / (fin + fout)
-            stddev = math.sqrt(var) * 0.87962566103423978
-            initializer.trunc_normal_(self.weight_v)
-            paddle.assign(self.weight_v * stddev, self.weight_v)
+            initializer.glorot_normal(self.weight_v)
 
             nn.initializer.Normal(mean, std)(self.weight_g)
             paddle.assign(paddle.exp(self.weight_g), self.weight_g)

@@ -105,7 +99,7 @@ def __init__(self, periods: Dict[str, Tuple[float, bool]]):
                 k: self.create_parameter(
                     [],
                     attr=paddle.ParamAttr(trainable=trainable),
-                    default_initializer=nn.initializer.Constant(2 * np.pi / eval(p)),
+                    default_initializer=nn.initializer.Constant(2 * np.pi / float(p)),
                 )  # mu = 2*pi / period for sin/cos function
                 for k, (p, trainable) in periods.items()
             }
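
Two independent cleanups here. The hand-rolled Glorot initialization (0.87962566103423978 is the standard deviation of a unit normal truncated at plus/minus two sigma, used to correct the variance of truncated samples) gives way to the library's initializer.glorot_normal, letting the now-unused import math go. Separately, eval(p) becomes float(p) for the period, presumably because eval accepts only strings and would raise once the YAML loader hands over a numeric 2.0, while float handles both and never executes config content. A quick illustration of the failure mode:

float("2.0")   # 2.0 -> works when the period arrives as a string
float(2.0)     # 2.0 -> and when YAML has already parsed it as a number
eval("2.0")    # 2.0 -> only works while the value is still a string
# eval(2.0)    # TypeError: eval() arg 1 must be a string, bytes or code object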

ppsci/loss/mtl/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,12 +16,14 @@
 
 from ppsci.loss.mtl.agda import AGDA
 from ppsci.loss.mtl.base import LossAggregator
+from ppsci.loss.mtl.grad_norm import GradNorm
 from ppsci.loss.mtl.pcgrad import PCGrad
 from ppsci.loss.mtl.relobralo import Relobralo
 from ppsci.loss.mtl.sum import Sum
 
 __all__ = [
     "AGDA",
+    "GradNorm",
     "LossAggregator",
     "PCGrad",
     "Relobralo",

ppsci/loss/mtl/agda.py

Lines changed: 2 additions & 2 deletions
@@ -19,10 +19,10 @@
 import paddle
 from paddle import nn
 
-from ppsci.loss.mtl.base import LossAggregator
+from ppsci.loss.mtl import base
 
 
-class AGDA(LossAggregator):
+class AGDA(base.LossAggregator):
     r"""
     **A**daptive **G**radient **D**escent **A**lgorithm
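A mechanical refactor with no behavioral change: AGDA now reaches the base class through the module (base.LossAggregator) rather than importing the class directly, plausibly to head off circular imports as the mtl package gains modules such as grad_norm.py, though the commit does not say.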