Skip to content

Commit dd971ee

Browse files
authored
Merge pull request #860 from HydrogenSulfate/fix_allen_cahn
[Example][WIP] Add Allen cahn default config
2 parents 68fe0cf + 3b05221 commit dd971ee

File tree

16 files changed

+552
-17
lines changed

16 files changed

+552
-17
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
3636
| 微分方程 | [分数阶微分方程](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/fpde/fractional_poisson_2d.py) | 机理驱动 | MLP | 无监督学习 | - | - |
3737
| 光孤子 | [Optical soliton](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/nlsmb) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)|
3838
| 光纤怪波 | [Optical rogue wave](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/nlsmb) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)|
39+
| 相场方程 | [Allen-Cahn](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/allen_cahn) | 机理驱动 | MLP | 无监督学习 | - | |
3940

4041
<br>
4142
<p align="center"><b>技术科学(AI for Technology)</b></p>

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
| 微分方程 | [分数阶微分方程](https://github.com/PaddlePaddle/PaddleScience/blob/develop/examples/fpde/fractional_poisson_2d.py) | 机理驱动 | MLP | 无监督学习 | - | - |
8282
| 光孤子 | [Optical soliton](./zh/examples/nlsmb.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)|
8383
| 光纤怪波 | [Optical rogue wave](./zh/examples/nlsmb.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://doi.org/10.1007/s11071-023-08824-w)|
84+
| 相场方程 | [Allen-Cahn](./zh/examples/allen_cahn.md) | 机理驱动 | MLP | 无监督学习 | - | |
8485

8586
<br>
8687
<p align="center"><b>技术科学(AI for Technology)</b></p>

docs/zh/api/loss/mtl.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
options:
66
members:
77
- AGDA
8+
- GradNorm
89
- LossAggregator
910
- PCGrad
1011
- Relobralo
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
"""
2+
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
3+
"""
4+
5+
from os import path as osp
6+
7+
import hydra
8+
import numpy as np
9+
import paddle
10+
import scipy.io as sio
11+
from matplotlib import pyplot as plt
12+
from omegaconf import DictConfig
13+
14+
import ppsci
15+
from ppsci.loss import mtl
16+
from ppsci.utils import misc
17+
18+
dtype = paddle.get_default_dtype()
19+
20+
21+
def plot(
22+
t_star: np.ndarray,
23+
x_star: np.ndarray,
24+
u_ref: np.ndarray,
25+
u_pred: np.ndarray,
26+
output_dir: str,
27+
):
28+
fig = plt.figure(figsize=(18, 5))
29+
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
30+
u_ref = u_ref.reshape([len(t_star), len(x_star)])
31+
32+
plt.subplot(1, 3, 1)
33+
plt.pcolor(TT, XX, u_ref, cmap="jet")
34+
plt.colorbar()
35+
plt.xlabel("t")
36+
plt.ylabel("x")
37+
plt.title("Exact")
38+
plt.tight_layout()
39+
40+
plt.subplot(1, 3, 2)
41+
plt.pcolor(TT, XX, u_pred, cmap="jet")
42+
plt.colorbar()
43+
plt.xlabel("t")
44+
plt.ylabel("x")
45+
plt.title("Predicted")
46+
plt.tight_layout()
47+
48+
plt.subplot(1, 3, 3)
49+
plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
50+
plt.colorbar()
51+
plt.xlabel("t")
52+
plt.ylabel("x")
53+
plt.title("Absolute error")
54+
plt.tight_layout()
55+
56+
fig_path = osp.join(output_dir, "ac.png")
57+
print(f"Saving figure to {fig_path}")
58+
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
59+
plt.close()
60+
61+
62+
def train(cfg: DictConfig):
63+
# set model
64+
model = ppsci.arch.MLP(**cfg.MODEL)
65+
66+
# set equation
67+
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}
68+
69+
# set constraint
70+
data = sio.loadmat(cfg.DATA_PATH)
71+
u_ref = data["usol"].astype(dtype) # (nt, nx)
72+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
73+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
74+
75+
u0 = u_ref[0, :] # [nx, ]
76+
77+
t0 = t_star[0] # float
78+
t1 = t_star[-1] # float
79+
80+
x0 = x_star[0] # float
81+
x1 = x_star[-1] # float
82+
83+
def gen_input_batch():
84+
tx = np.random.uniform(
85+
[t0, x0],
86+
[t1, x1],
87+
(cfg.TRAIN.batch_size, 2),
88+
).astype(dtype)
89+
return {
90+
"t": np.sort(tx[:, 0:1], axis=0),
91+
"x": tx[:, 1:2],
92+
}
93+
94+
def gen_label_batch(input_batch):
95+
return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}
96+
97+
pde_constraint = ppsci.constraint.SupervisedConstraint(
98+
{
99+
"dataset": {
100+
"name": "ContinuousNamedArrayDataset",
101+
"input": gen_input_batch,
102+
"label": gen_label_batch,
103+
},
104+
},
105+
output_expr=equation["AllenCahn"].equations,
106+
loss=ppsci.loss.CausalMSELoss(
107+
cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
108+
),
109+
name="PDE",
110+
)
111+
112+
ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
113+
ic_label = {"u": u0.reshape([-1, 1])}
114+
ic = ppsci.constraint.SupervisedConstraint(
115+
{
116+
"dataset": {
117+
"name": "IterableNamedArrayDataset",
118+
"input": ic_input,
119+
"label": ic_label,
120+
},
121+
},
122+
output_expr={"u": lambda out: out["u"]},
123+
loss=ppsci.loss.MSELoss("mean"),
124+
name="IC",
125+
)
126+
# wrap constraints together
127+
constraint = {
128+
pde_constraint.name: pde_constraint,
129+
ic.name: ic,
130+
}
131+
132+
# set optimizer
133+
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
134+
**cfg.TRAIN.lr_scheduler
135+
)()
136+
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
137+
138+
# set validator
139+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
140+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
141+
eval_label = {"u": u_ref.reshape([-1, 1])}
142+
u_validator = ppsci.validate.SupervisedValidator(
143+
{
144+
"dataset": {
145+
"name": "NamedArrayDataset",
146+
"input": eval_data,
147+
"label": eval_label,
148+
},
149+
"batch_size": cfg.EVAL.batch_size,
150+
},
151+
ppsci.loss.MSELoss("mean"),
152+
{"u": lambda out: out["u"]},
153+
metric={"L2Rel": ppsci.metric.L2Rel()},
154+
name="u_validator",
155+
)
156+
validator = {u_validator.name: u_validator}
157+
158+
# initialize solver
159+
solver = ppsci.solver.Solver(
160+
model,
161+
constraint,
162+
cfg.output_dir,
163+
optimizer,
164+
epochs=cfg.TRAIN.epochs,
165+
iters_per_epoch=cfg.TRAIN.iters_per_epoch,
166+
save_freq=cfg.TRAIN.save_freq,
167+
log_freq=cfg.log_freq,
168+
eval_during_train=True,
169+
eval_freq=cfg.TRAIN.eval_freq,
170+
equation=equation,
171+
validator=validator,
172+
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
173+
checkpoint_path=cfg.TRAIN.checkpoint_path,
174+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
175+
loss_aggregator=mtl.GradNorm(
176+
model,
177+
len(constraint),
178+
cfg.TRAIN.grad_norm.update_freq,
179+
cfg.TRAIN.grad_norm.momentum,
180+
),
181+
cfg=cfg,
182+
)
183+
# train model
184+
solver.train()
185+
# evaluate after finished training
186+
solver.eval()
187+
# visualize prediction after finished training
188+
u_pred = solver.predict(
189+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
190+
)["u"]
191+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
192+
193+
# plot
194+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
195+
196+
197+
def evaluate(cfg: DictConfig):
198+
# set model
199+
model = ppsci.arch.MLP(**cfg.MODEL)
200+
201+
data = sio.loadmat(cfg.DATA_PATH)
202+
u_ref = data["usol"].astype(dtype) # (nt, nx)
203+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
204+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
205+
206+
# set validator
207+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
208+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
209+
eval_label = {"u": u_ref.reshape([-1, 1])}
210+
u_validator = ppsci.validate.SupervisedValidator(
211+
{
212+
"dataset": {
213+
"name": "NamedArrayDataset",
214+
"input": eval_data,
215+
"label": eval_label,
216+
},
217+
"batch_size": cfg.EVAL.batch_size,
218+
},
219+
ppsci.loss.MSELoss("mean"),
220+
{"u": lambda out: out["u"]},
221+
metric={"L2Rel": ppsci.metric.L2Rel()},
222+
name="u_validator",
223+
)
224+
validator = {u_validator.name: u_validator}
225+
226+
# initialize solver
227+
solver = ppsci.solver.Solver(
228+
model,
229+
output_dir=cfg.output_dir,
230+
log_freq=cfg.log_freq,
231+
validator=validator,
232+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
233+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
234+
)
235+
236+
# evaluate after finished training
237+
solver.eval()
238+
# visualize prediction after finished training
239+
u_pred = solver.predict(
240+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
241+
)["u"]
242+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
243+
244+
# plot
245+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
246+
247+
248+
def export(cfg: DictConfig):
249+
# set model
250+
model = ppsci.arch.MLP(**cfg.MODEL)
251+
252+
# initialize solver
253+
solver = ppsci.solver.Solver(
254+
model,
255+
pretrained_model_path=cfg.INFER.pretrained_model_path,
256+
)
257+
# export model
258+
from paddle.static import InputSpec
259+
260+
input_spec = [
261+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
262+
]
263+
solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)
264+
265+
266+
def inference(cfg: DictConfig):
267+
from deploy.python_infer import pinn_predictor
268+
269+
predictor = pinn_predictor.PINNPredictor(cfg)
270+
data = sio.loadmat(cfg.DATA_PATH)
271+
u_ref = data["usol"].astype(dtype) # (nt, nx)
272+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
273+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
274+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
275+
276+
input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
277+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
278+
output_dict = {
279+
store_key: output_dict[infer_key]
280+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
281+
}
282+
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
283+
# mapping data to cfg.INFER.output_keys
284+
285+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
286+
287+
288+
@hydra.main(
289+
version_base=None, config_path="./conf", config_name="allen_cahn_default.yaml"
290+
)
291+
def main(cfg: DictConfig):
292+
if cfg.mode == "train":
293+
train(cfg)
294+
elif cfg.mode == "eval":
295+
evaluate(cfg)
296+
elif cfg.mode == "export":
297+
export(cfg)
298+
elif cfg.mode == "infer":
299+
inference(cfg)
300+
else:
301+
raise ValueError(
302+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
303+
)
304+
305+
306+
if __name__ == "__main__":
307+
main()

examples/allen_cahn/conf/allen_cahn.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ MODEL:
3939
hidden_size: 256
4040
activation: tanh
4141
periods:
42-
t: [2.0, False]
42+
x: [2.0, False]
4343

4444
# training settings
4545
TRAIN:

examples/allen_cahn/conf/allen_cahn_causal_fourier_rwf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ MODEL:
3939
hidden_size: 256
4040
activation: tanh
4141
periods:
42-
t: [2.0, False]
42+
x: [2.0, False]
4343
fourier:
4444
dim: 256
4545
scale: 1.0

0 commit comments

Comments
 (0)