Skip to content

Commit beae584

Browse files
[Example] Add allen cahn example (#845)
* allow empty optimizer when saving checkpoint * add model averaging module * fix return dtype inconsistency with global dtype * use python func instead of sympy function for pow(u,3) get a bit poor L2 error than multiply(u*u*u) * refine AllenCahn docstring * support save and load for average model module * add 3 ema unitests * update 2023 to 2024 * add ema config pydantic scheme * add avg_range for SWA * update field_validator for swa and ema * support period embedding for MLP * Keep non-float data when reading file * update ema and save_load, printer and eval, solver module code * add allen_cahn example * refine code * save buffer and non-grad required params in ema * add unitest for ema with buffer * fix epoch_ema saving * add unitest for ema state_dict * refine allen_cahn_plain.py * fix string to floating conversion in reader.py * fix string to floating conversion in reader.py * remove print code in solver * Update allen_cahn_plain.py * Update misc.py --------- Co-authored-by: zzm <[email protected]>
1 parent 825a44e commit beae584

File tree

20 files changed

+1197
-29
lines changed

20 files changed

+1197
-29
lines changed

deploy/python_infer/pinn_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def predict(
9797

9898
# inference by batch
9999
for batch_id in range(1, batch_num + 1):
100-
if batch_id % self.log_freq == 0 or batch_id == batch_num:
100+
if batch_id == 1 or batch_id % self.log_freq == 0 or batch_id == batch_num:
101101
logger.info(f"Predicting batch {batch_id}/{batch_num}")
102102

103103
# prepare batch input dict

docs/zh/api/data/dataset.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- ChipHeatDataset
1010
- CSVDataset
1111
- IterableCSVDataset
12+
- ContinuousNamedArrayDataset
1213
- ERA5Dataset
1314
- ERA5SampledDataset
1415
- IterableMatDataset
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
"""
2+
Reference: https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html
3+
"""
4+
5+
from os import path as osp
6+
7+
import hydra
8+
import numpy as np
9+
import paddle
10+
import scipy.io as sio
11+
from matplotlib import pyplot as plt
12+
from omegaconf import DictConfig
13+
14+
import ppsci
15+
from ppsci.utils import misc
16+
17+
dtype = paddle.get_default_dtype()
18+
19+
20+
def plot(
21+
t_star: np.ndarray,
22+
x_star: np.ndarray,
23+
u_ref: np.ndarray,
24+
u_pred: np.ndarray,
25+
output_dir: str,
26+
):
27+
fig = plt.figure(figsize=(18, 5))
28+
TT, XX = np.meshgrid(t_star, x_star, indexing="ij")
29+
u_ref = u_ref.reshape([len(t_star), len(x_star)])
30+
31+
plt.subplot(1, 3, 1)
32+
plt.pcolor(TT, XX, u_ref, cmap="jet")
33+
plt.colorbar()
34+
plt.xlabel("t")
35+
plt.ylabel("x")
36+
plt.title("Exact")
37+
plt.tight_layout()
38+
39+
plt.subplot(1, 3, 2)
40+
plt.pcolor(TT, XX, u_pred, cmap="jet")
41+
plt.colorbar()
42+
plt.xlabel("t")
43+
plt.ylabel("x")
44+
plt.title("Predicted")
45+
plt.tight_layout()
46+
47+
plt.subplot(1, 3, 3)
48+
plt.pcolor(TT, XX, np.abs(u_ref - u_pred), cmap="jet")
49+
plt.colorbar()
50+
plt.xlabel("t")
51+
plt.ylabel("x")
52+
plt.title("Absolute error")
53+
plt.tight_layout()
54+
55+
fig_path = osp.join(output_dir, "ac.png")
56+
fig.savefig(fig_path, bbox_inches="tight", dpi=400)
57+
plt.close()
58+
59+
60+
def train(cfg: DictConfig):
61+
# set model
62+
model = ppsci.arch.MLP(**cfg.MODEL)
63+
64+
# set equation
65+
equation = {"AllenCahn": ppsci.equation.AllenCahn(0.01**2)}
66+
67+
data = sio.loadmat(cfg.DATA_PATH)
68+
u_ref = data["usol"].astype(dtype) # (nt, nx)
69+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
70+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
71+
72+
u0 = u_ref[0, :] # [nx, ]
73+
74+
t0 = t_star[0] # float
75+
t1 = t_star[-1] # float
76+
77+
x0 = x_star[0] # float
78+
x1 = x_star[-1] # float
79+
80+
# set constraint
81+
def gen_input_batch():
82+
tx = np.random.uniform(
83+
[t0, x0],
84+
[t1, x1],
85+
(cfg.TRAIN.batch_size, 2),
86+
).astype(dtype)
87+
return {
88+
"t": tx[:, 0:1],
89+
"x": tx[:, 1:2],
90+
}
91+
92+
def gen_label_batch(input_batch):
93+
return {"allen_cahn": np.zeros([cfg.TRAIN.batch_size, 1], dtype)}
94+
95+
pde_constraint = ppsci.constraint.SupervisedConstraint(
96+
{
97+
"dataset": {
98+
"name": "ContinuousNamedArrayDataset",
99+
"input": gen_input_batch,
100+
"label": gen_label_batch,
101+
},
102+
},
103+
output_expr=equation["AllenCahn"].equations,
104+
loss=ppsci.loss.MSELoss(),
105+
name="PDE",
106+
)
107+
108+
ic_input = {"t": np.full([len(x_star), 1], t0), "x": x_star.reshape([-1, 1])}
109+
ic_label = {"u": u0.reshape([-1, 1])}
110+
ic = ppsci.constraint.SupervisedConstraint(
111+
{
112+
"dataset": {
113+
"name": "IterableNamedArrayDataset",
114+
"input": ic_input,
115+
"label": ic_label,
116+
},
117+
},
118+
output_expr={"u": lambda out: out["u"]},
119+
loss=ppsci.loss.MSELoss("mean"),
120+
name="IC",
121+
)
122+
# wrap constraints together
123+
constraint = {
124+
pde_constraint.name: pde_constraint,
125+
ic.name: ic,
126+
}
127+
128+
# set optimizer
129+
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
130+
**cfg.TRAIN.lr_scheduler
131+
)()
132+
optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
133+
134+
# set validator
135+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
136+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
137+
eval_label = {"u": u_ref.reshape([-1, 1])}
138+
u_validator = ppsci.validate.SupervisedValidator(
139+
{
140+
"dataset": {
141+
"name": "NamedArrayDataset",
142+
"input": eval_data,
143+
"label": eval_label,
144+
},
145+
"batch_size": cfg.EVAL.batch_size,
146+
},
147+
ppsci.loss.MSELoss("mean"),
148+
{"u": lambda out: out["u"]},
149+
metric={"L2Rel": ppsci.metric.L2Rel()},
150+
name="u_validator",
151+
)
152+
validator = {u_validator.name: u_validator}
153+
154+
# initialize solver
155+
solver = ppsci.solver.Solver(
156+
model,
157+
constraint,
158+
cfg.output_dir,
159+
optimizer,
160+
lr_scheduler,
161+
cfg.TRAIN.epochs,
162+
cfg.TRAIN.iters_per_epoch,
163+
save_freq=cfg.TRAIN.save_freq,
164+
log_freq=cfg.log_freq,
165+
eval_during_train=True,
166+
eval_freq=cfg.TRAIN.eval_freq,
167+
seed=cfg.seed,
168+
equation=equation,
169+
validator=validator,
170+
pretrained_model_path=cfg.TRAIN.pretrained_model_path,
171+
checkpoint_path=cfg.TRAIN.checkpoint_path,
172+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
173+
use_tbd=True,
174+
cfg=cfg,
175+
)
176+
# train model
177+
solver.train()
178+
# evaluate after finished training
179+
solver.eval()
180+
# visualize prediction after finished training
181+
u_pred = solver.predict(
182+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
183+
)["u"]
184+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
185+
186+
# plot
187+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
188+
189+
190+
def evaluate(cfg: DictConfig):
191+
# set model
192+
model = ppsci.arch.MLP(**cfg.MODEL)
193+
194+
data = sio.loadmat(cfg.DATA_PATH)
195+
u_ref = data["usol"].astype(dtype) # (nt, nx)
196+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
197+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
198+
199+
# set validator
200+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
201+
eval_data = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
202+
eval_label = {"u": u_ref.reshape([-1, 1])}
203+
u_validator = ppsci.validate.SupervisedValidator(
204+
{
205+
"dataset": {
206+
"name": "NamedArrayDataset",
207+
"input": eval_data,
208+
"label": eval_label,
209+
},
210+
"batch_size": cfg.EVAL.batch_size,
211+
},
212+
ppsci.loss.MSELoss("mean"),
213+
{"u": lambda out: out["u"]},
214+
metric={"L2Rel": ppsci.metric.L2Rel()},
215+
name="u_validator",
216+
)
217+
validator = {u_validator.name: u_validator}
218+
219+
# initialize solver
220+
solver = ppsci.solver.Solver(
221+
model,
222+
output_dir=cfg.output_dir,
223+
log_freq=cfg.log_freq,
224+
validator=validator,
225+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
226+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
227+
)
228+
229+
# evaluate after finished training
230+
solver.eval()
231+
# visualize prediction after finished training
232+
u_pred = solver.predict(
233+
eval_data, batch_size=cfg.EVAL.batch_size, return_numpy=True
234+
)["u"]
235+
u_pred = u_pred.reshape([len(t_star), len(x_star)])
236+
237+
# plot
238+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
239+
240+
241+
def export(cfg: DictConfig):
242+
# set model
243+
model = ppsci.arch.MLP(**cfg.MODEL)
244+
245+
# initialize solver
246+
solver = ppsci.solver.Solver(
247+
model,
248+
pretrained_model_path=cfg.INFER.pretrained_model_path,
249+
)
250+
# export model
251+
from paddle.static import InputSpec
252+
253+
input_spec = [
254+
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys},
255+
]
256+
solver.export(input_spec, cfg.INFER.export_path, with_onnx=False)
257+
258+
259+
def inference(cfg: DictConfig):
260+
from deploy.python_infer import pinn_predictor
261+
262+
predictor = pinn_predictor.PINNPredictor(cfg)
263+
data = sio.loadmat(cfg.DATA_PATH)
264+
u_ref = data["usol"].astype(dtype) # (nt, nx)
265+
t_star = data["t"].flatten().astype(dtype) # [nt, ]
266+
x_star = data["x"].flatten().astype(dtype) # [nx, ]
267+
tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
268+
269+
input_dict = {"t": tx_star[:, 0:1], "x": tx_star[:, 1:2]}
270+
output_dict = predictor.predict(input_dict, cfg.INFER.batch_size)
271+
output_dict = {
272+
store_key: output_dict[infer_key]
273+
for store_key, infer_key in zip(cfg.MODEL.output_keys, output_dict.keys())
274+
}
275+
u_pred = output_dict["u"].reshape([len(t_star), len(x_star)])
276+
# mapping data to cfg.INFER.output_keys
277+
278+
plot(t_star, x_star, u_ref, u_pred, cfg.output_dir)
279+
280+
281+
@hydra.main(version_base=None, config_path="./conf", config_name="allen_cahn.yaml")
282+
def main(cfg: DictConfig):
283+
if cfg.mode == "train":
284+
train(cfg)
285+
elif cfg.mode == "eval":
286+
evaluate(cfg)
287+
elif cfg.mode == "export":
288+
export(cfg)
289+
elif cfg.mode == "infer":
290+
inference(cfg)
291+
else:
292+
raise ValueError(
293+
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
294+
)
295+
296+
297+
if __name__ == "__main__":
298+
main()

0 commit comments

Comments
 (0)