
Commit f52394d

[Example] Add allen_cahn causal train with fourier feature and random weight factorization (#848)
* allow empty optimizer when saving checkpoint
* add model averaging module
* fix return dtype inconsistency with global dtype
* use python func instead of sympy function for pow(u,3); gives a slightly worse L2 error than multiply(u*u*u)
* refine AllenCahn docstring
* support save and load for average model module
* add 3 EMA unit tests
* update 2023 to 2024
* add EMA config pydantic scheme
* add avg_range for SWA
* update field_validator for SWA and EMA
* support period embedding for MLP
* keep non-float data when reading file
* update EMA and save_load, printer and eval, solver module code
* add allen_cahn example
* refine code
* save buffer and non-grad-required params in EMA
* add unit test for EMA with buffer
* fix epoch_ema saving
* add unit test for EMA state_dict
* refine allen_cahn_plain.py
* fix string-to-floating conversion in reader.py
* fix string-to-floating conversion in reader.py
* update code and refine document
* correct initialization for RWF
* update docstring for arg 'random_weight' of mlp
* update docstrings
* add causal fourier RWF config
* fix code in mlp.py
* refine code in mse.py
1 parent c30b6e0 commit f52394d
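
The commit title names the two architectural ingredients this example adds on top of the plain Allen-Cahn script: a random Fourier feature embedding of the network input and random weight factorization (RWF) of the dense layers. For orientation, below is a minimal NumPy sketch of both ideas. It is not the `ppsci.arch.MLP` implementation (whose actual argument names, e.g. `random_weight`, are documented in this commit's docstring changes); the helper names and hyperparameters are illustrative only.

import numpy as np

rng = np.random.default_rng(0)


def fourier_features(x, num_features=128, sigma=2.0):
    # Map inputs x of shape [N, d] to [cos(x @ B), sin(x @ B)] with a fixed
    # random projection B ~ N(0, sigma^2); the embedding, not x, feeds the MLP.
    B = rng.normal(0.0, sigma, size=(x.shape[1], num_features // 2))
    proj = x @ B
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=1)


def rwf_init(fan_in, fan_out, mean=1.0, std=0.1):
    # Random weight factorization: start from a Glorot-style matrix W0 and
    # re-parameterize it as W = exp(s) * V, with one trainable scale s per
    # output neuron (s ~ N(mean, std)) and V = W0 / exp(s).
    w0 = rng.normal(0.0, np.sqrt(2.0 / (fan_in + fan_out)), size=(fan_in, fan_out))
    s = rng.normal(mean, std, size=(fan_out,))
    v = w0 / np.exp(s)
    return s, v


def rwf_dense(x, s, v, b):
    # The effective weight exp(s) * V is rebuilt on every forward pass, so the
    # optimizer updates the factorized parameters s and V instead of W directly.
    return x @ (np.exp(s) * v) + b

At initialization the factorized layer reproduces W0 exactly; the factorization only changes the optimization geometry, which is why "correct initialization for RWF" is one of the fixes listed above.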

File tree

10 files changed (+597, -24 lines)


docs/zh/api/loss/loss.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 - L2RelLoss
 - MAELoss
 - MSELoss
+- CausalMSELoss
 - MSELossWithL2Decay
 - IntegralLoss
 - PeriodicL1Loss
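
The newly documented `CausalMSELoss` is the loss used by the PDE constraint in the example below (`n_chunks` and `tol` come from `cfg.TRAIN.causal`). The idea, following the causal-training literature this example builds on: sort the collocation points by time, split them into chunks, and down-weight each chunk's MSE by the accumulated residual loss of all earlier chunks, so later times are only fitted once earlier times are resolved. A minimal NumPy sketch of that weighting under the standard formulation; this is not the ppsci implementation:

import numpy as np


def causal_mse(residuals, n_chunks=32, tol=1.0):
    # residuals: [N, 1] PDE residuals whose rows are already sorted by t
    # (the example sorts its sampled "t" batch for exactly this reason);
    # N is assumed to be divisible by n_chunks.
    chunks = np.split(residuals, n_chunks)               # temporal chunks
    losses = np.array([np.mean(c**2) for c in chunks])   # per-chunk MSE
    # w_i = exp(-tol * sum of losses of all earlier chunks); the first chunk
    # gets weight 1 and the weights are treated as constants (no gradient).
    weights = np.exp(-tol * np.concatenate([[0.0], np.cumsum(losses[:-1])]))
    return np.mean(weights * losses)

Once all weights are close to 1 the loss reduces to the plain mean-squared residual, which the original causal-training formulation uses as its convergence diagnostic.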
Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
"""
Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
"""

from os import path as osp

import hydra
import numpy as np
import paddle
import scipy.io as sio
from matplotlib import pyplot as plt
from omegaconf import DictConfig

import ppsci
from ppsci.utils import misc

dtype = paddle.get_default_dtype()


def plot(
    t_star: np.ndarray,
    x_star: np.ndarray,
    u_ref: np.ndarray,
    u_pred: np.ndarray,
    output_dir: str,
):
    fig = plt.figure(figsize=(18, 5))
    TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
    u_ref = u_ref.reshape([len(t_star), len(x_star)])

    plt.subplot(1, 3, 1)
    plt.pcolor(TT, XX, u_ref, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Exact")
    plt.tight_layout()

    plt.subplot(1, 3, 2)
    plt.pcolor(TT, XX, u_pred, cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Predicted")
    plt.tight_layout()

    plt.subplot(1, 3, 3)
    plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
    plt.colorbar()
    plt.xlabel("t")
    plt.ylabel("x")
    plt.title("Absolute error")
    plt.tight_layout()

    fig_path = osp.join(output_dir, "ac.png")
    print(f"Saving figure to {fig_path}")
    fig.savefig(fig_path, bbox_inches="tight", dpi=400)
    plt.close()


def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # set equation
    equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}

    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]

    u0 = u_ref[0, :]  # [nx, ]

    t0 = t_star[0]  # float
    t1 = t_star[-1]  # float

    x0 = x_star[0]  # float
    x1 = x_star[-1]  # float

    # set constraint
    def gen_input_batch():
        tx = np.random.uniform(
            [t0, x0],
            [t1, x1],
            (cfg.TRAIN.batch_size, 2),
        ).astype(dtype)
        return {
            "t": np.sort(tx[:, 0:1], axis=0),
            "x": tx[:, 1:2],
        }

    def gen_label_batch(input_batch):
        return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}

    pde_constraint = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "ContinuousNamedArrayDataset",
                "input": gen_input_batch,
                "label": gen_label_batch,
            },
        },
        output_expr=equation["AllenCahn"].equations,
        loss=ppsci.loss.CausalMSELoss(
            cfg.TRAIN.causal.n_chunks, "mean", tol=cfg.TRAIN.causal.tol
        ),
        name="PDE",
    )

    ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
    ic_label = {"u": u0.reshape([-1, 1])}
    ic = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "IterableNamedArrayDataset",
                "input": ic_input,
                "label": ic_label,
            },
        },
        output_expr={"u": lambda out: out["u"]},
        loss=ppsci.loss.MSELoss("mean"),
        name="IC",
    )
    # wrap constraints together
    constraint = {
        pde_constraint.name: pde_constraint,
        ic.name: ic,
    }

    # set optimizer
    lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        **cfg.TRAIN.lr_scheduler
    )()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # set validator
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    eval_label = {"u": u_ref.reshape([-1, 1])}
    u_validator = ppsci.validate.SupervisedValidator(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": eval_data,
                "label": eval_label,
            },
            "batch_size": cfg.EVAL.batch_size,
        },
        ppsci.loss.MSELoss("mean"),
        {"u": lambda out: out["u"]},
        metric={"L2Rel": ppsci.metric.L2Rel()},
        name="u_validator",
    )
    validator = {u_validator.name: u_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        cfg.TRAIN.iters_per_epoch,
        save_freq=cfg.TRAIN.save_freq,
        log_freq=cfg.log_freq,
        eval_during_train=True,
        eval_freq=cfg.TRAIN.eval_freq,
        seed=cfg.seed,
        equation=equation,
        validator=validator,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
        checkpoint_path=cfg.TRAIN.checkpoint_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
        use_tbd=True,
        cfg=cfg,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()
    # visualize prediction after finished training
    u_pred = solver.predict(
        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
    )["u"]
    u_pred = u_pred.reshape([len(t_star), len(x_star)])

    # plot
    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def evaluate(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]

    # set validator
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
    eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    eval_label = {"u": u_ref.reshape([-1, 1])}
    u_validator = ppsci.validate.SupervisedValidator(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": eval_data,
                "label": eval_label,
            },
            "batch_size": cfg.EVAL.batch_size,
        },
        ppsci.loss.MSELoss("mean"),
        {"u": lambda out: out["u"]},
        metric={"L2Rel": ppsci.metric.L2Rel()},
        name="u_validator",
    )
    validator = {u_validator.name: u_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # evaluate after finished training
    solver.eval()
    # visualize prediction after finished training
    u_pred = solver.predict(
        eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
    )["u"]
    u_pred = u_pred.reshape([len(t_star), len(x_star)])

    # plot
    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.MLP(**cfg.MODEL)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
    ]
    solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    predictor = pinn_predictor.PINNPredictor(cfg)
    data = sio.loadmat(cfg.DATA_PATH)
    u_ref = data["usol"].astype(dtype)  # (nt, nx)
    t_star = data["t"].flatten().astype(dtype)  # [nt, ]
    x_star = data["x"].flatten().astype(dtype)  # [nx, ]
    tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)

    input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
    output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
    # mapping data to cfg.INFER.output_keys
    output_dict = {
        store_key: output_dict[infer_key]
        for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
    }
    u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])

    plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)


@hydra.main(
    version_base=None, config_path="./conf", config_name="allen_cahn_causal.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()

examples/allen_cahn/allen_cahn_plain.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 """
-Reference: https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html
+Reference: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/main/examples/allen_cahn
 """
 
 from os import path as osp
@@ -53,6 +53,7 @@ def plot(
     plt.tight_layout()
 
     fig_path = osp.join(output_dir, "ac.png")
+    print(f"Saving figure to {fig_path}")
     fig.savefig(fig_path, bbox_inches="tight", dpi=400)
     plt.close()
 
@@ -101,7 +102,7 @@ def gen_label_batch(input_batch):
             },
         },
         output_expr=equation["AllenCahn"].equations,
-        loss=ppsci.loss.MSELoss(),
+        loss=ppsci.loss.MSELoss("mean"),
        name="PDE",
     )
 

examples/allen_cahn/conf/allen_cahn.yaml

Lines changed: 0 additions & 3 deletions
@@ -58,9 +58,6 @@ TRAIN:
   batch_size: 4096
   pretrained_model_path: null
   checkpoint_path: null
-  ema:
-    decay: 0.9
-    avg_freq: 1
 
 # evaluation settings
 EVAL:
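
The `ema` block removed from this plain config (decay 0.9, applied every `avg_freq` step) belongs to the model-averaging (EMA/SWA) module this commit adds and wires into the checkpoint save/load path. The update itself is the standard exponential moving average of the weights; below is a minimal Paddle sketch of that mechanism, independent of the actual ppsci module's interface (class and method names are illustrative):

import paddle


class SimpleEMA:
    def __init__(self, model: paddle.nn.Layer, decay: float = 0.9):
        self.decay = decay
        # detached copies of the current parameters act as the shadow weights
        self.shadow = {
            name: param.detach().clone() for name, param in model.named_parameters()
        }

    def update(self, model: paddle.nn.Layer):
        # shadow <- decay * shadow + (1 - decay) * param, called every avg_freq steps
        with paddle.no_grad():
            for name, param in model.named_parameters():
                self.shadow[name] = (
                    self.decay * self.shadow[name] + (1.0 - self.decay) * param
                )

    def apply_to(self, model: paddle.nn.Layer):
        # copy the averaged weights back into the model, e.g. before evaluation
        with paddle.no_grad():
            for name, param in model.named_parameters():
                param.set_value(self.shadow[name])

With the values shown in the removed block, the shadow weights would be refreshed every training step with decay 0.9.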

0 commit comments
