-
Notifications
You must be signed in to change notification settings - Fork 225
【PPSCI Export&Infer No.1】 Add export and inference for deephpms_schrodinger -part #1170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
88a4117
91e8423
d2db081
eb54b93
5d009cc
a553b46
8269102
040da82
48ce0bb
7c15562
fd12e01
ce4f096
feecb0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -535,14 +535,196 @@ def transform_fg(_in): | |||||||
) | ||||||||
|
||||||||
|
||||||||
def export(cfg: DictConfig): | ||||||||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||||||||
# initialize logger | ||||||||
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") | ||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
# initialize boundaries | ||||||||
t_lb = paddle.to_tensor(cfg.T_LB) | ||||||||
t_ub = paddle.to_tensor(np.pi / cfg.T_UB) | ||||||||
x_lb = paddle.to_tensor(cfg.X_LB) | ||||||||
x_ub = paddle.to_tensor(cfg.X_UB) | ||||||||
|
||||||||
# initialize models | ||||||||
model_idn_u = ppsci.arch.MLP(**cfg.MODEL.idn_u_net) | ||||||||
model_idn_v = ppsci.arch.MLP(**cfg.MODEL.idn_v_net) | ||||||||
|
||||||||
# initialize transform | ||||||||
def transform_uv(_in): | ||||||||
t, x = _in["t"], _in["x"] | ||||||||
t = 2.0 * (t - t_lb) * paddle.pow((t_ub - t_lb), -1) - 1.0 | ||||||||
x = 2.0 * (x - x_lb) * paddle.pow((x_ub - x_lb), -1) - 1.0 | ||||||||
input_trans = {"t": t, "x": x} | ||||||||
return input_trans | ||||||||
|
||||||||
# register transform | ||||||||
model_idn_u.register_input_transform(transform_uv) | ||||||||
model_idn_v.register_input_transform(transform_uv) | ||||||||
|
||||||||
# initialize model list | ||||||||
model_list = ppsci.arch.ModelList((model_idn_u, model_idn_v)) | ||||||||
|
||||||||
# 加载预训练模型 | ||||||||
save_load.load_pretrain(model_list, cfg.INFER.pretrained_model_path) | ||||||||
model_list.eval() | ||||||||
|
||||||||
# 确保导出目录存在 | ||||||||
import json | ||||||||
import os | ||||||||
|
||||||||
export_dir = os.path.dirname(cfg.INFER.export_path) | ||||||||
if export_dir and not os.path.exists(export_dir): | ||||||||
os.makedirs(export_dir, exist_ok=True) | ||||||||
|
||||||||
# 保存模型参数 | ||||||||
params_save_path = cfg.INFER.export_path + "_params.pdparams" | ||||||||
paddle.save(model_list.state_dict(), params_save_path) | ||||||||
|
||||||||
# 保存模型结构信息 | ||||||||
model_info = { | ||||||||
"model_type": "PaddleScience_Schrodinger", | ||||||||
"input_keys": ["t", "x"], | ||||||||
"output_keys": ["u_idn", "v_idn"], | ||||||||
"boundaries": { | ||||||||
"T_LB": float(cfg.T_LB), | ||||||||
"T_UB": float(cfg.T_UB), | ||||||||
"X_LB": float(cfg.X_LB), | ||||||||
"X_UB": float(cfg.X_UB), | ||||||||
}, | ||||||||
} | ||||||||
|
||||||||
info_save_path = cfg.INFER.export_path + "_info.json" | ||||||||
with open(info_save_path, "w", encoding="utf-8") as f: | ||||||||
json.dump(model_info, f, indent=2, ensure_ascii=False) | ||||||||
|
||||||||
logger.info(f"Model exported to {params_save_path} and {info_save_path}") | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 模型导出是指把模型保存为静态图模型,请参考其它有export的案例代码 |
||||||||
|
||||||||
|
||||||||
def inference(cfg: DictConfig): | ||||||||
ppsci.utils.misc.set_random_seed(cfg.seed) | ||||||||
# initialize logger | ||||||||
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") | ||||||||
|
||||||||
# initialize boundaries | ||||||||
t_lb = paddle.to_tensor(cfg.T_LB) | ||||||||
t_ub = paddle.to_tensor(np.pi / cfg.T_UB) | ||||||||
x_lb = paddle.to_tensor(cfg.X_LB) | ||||||||
x_ub = paddle.to_tensor(cfg.X_UB) | ||||||||
|
||||||||
# initialize models - 只初始化需要的模型 | ||||||||
model_idn_u = ppsci.arch.MLP(**cfg.MODEL.idn_u_net) | ||||||||
model_idn_v = ppsci.arch.MLP(**cfg.MODEL.idn_v_net) | ||||||||
|
||||||||
# initialize transform | ||||||||
def transform_uv(_in): | ||||||||
t, x = _in["t"], _in["x"] | ||||||||
t = 2.0 * (t - t_lb) * paddle.pow((t_ub - t_lb), -1) - 1.0 | ||||||||
x = 2.0 * (x - x_lb) * paddle.pow((x_ub - x_lb), -1) - 1.0 | ||||||||
input_trans = {"t": t, "x": x} | ||||||||
return input_trans | ||||||||
|
||||||||
# register transform | ||||||||
model_idn_u.register_input_transform(transform_uv) | ||||||||
model_idn_v.register_input_transform(transform_uv) | ||||||||
|
||||||||
# initialize model list - 只包含需要的模型 | ||||||||
model_list = ppsci.arch.ModelList((model_idn_u, model_idn_v)) | ||||||||
|
||||||||
# load pretrained model | ||||||||
save_load.load_pretrain(model_list, cfg.INFER.pretrained_model_path) | ||||||||
|
||||||||
# 尝试加载数据集以获得与eval相同的网格点 | ||||||||
try: | ||||||||
dataset_path = getattr(cfg, "INFER.pretrained_model_path", "./datasets/NLS.mat") | ||||||||
if not osp.exists(dataset_path): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 模型推理是指加载静态图模型进行推理,而不是动态图模型 |
||||||||
dataset_path = "./datasets/NLS.mat" | ||||||||
|
||||||||
dataset_val = reader.load_mat_file( | ||||||||
dataset_path, | ||||||||
keys=("t", "x", "uv_sol", "u_sol", "v_sol"), | ||||||||
alias_dict={ | ||||||||
"t": "t_ori", | ||||||||
"x": "x_ori", | ||||||||
"uv_sol": "Exact_uv_ori", | ||||||||
"u_sol": "u_star", | ||||||||
"v_sol": "v_star", | ||||||||
}, | ||||||||
) | ||||||||
|
||||||||
# 使用数据集中的网格点 | ||||||||
t_mesh, x_mesh = np.meshgrid( | ||||||||
np.squeeze(dataset_val["t"]), np.squeeze(dataset_val["x"]) | ||||||||
) | ||||||||
|
||||||||
except Exception: | ||||||||
# 回退到默认网格 | ||||||||
t_points = np.linspace(cfg.T_LB, cfg.T_UB, cfg.INFER.t_points) | ||||||||
x_points = np.linspace(cfg.X_LB, cfg.X_UB, cfg.INFER.x_points) | ||||||||
t_mesh, x_mesh = np.meshgrid(t_points, x_points) | ||||||||
dataset_val = None | ||||||||
|
||||||||
t_flatten = paddle.to_tensor( | ||||||||
t_mesh.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False | ||||||||
) | ||||||||
x_flatten = paddle.to_tensor( | ||||||||
x_mesh.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False | ||||||||
) | ||||||||
|
||||||||
# 使用模型进行预测 | ||||||||
pred = model_list({"t": t_flatten, "x": x_flatten}) | ||||||||
u_pred = pred["u_idn"].numpy() | ||||||||
v_pred = pred["v_idn"].numpy() | ||||||||
uv_pred = np.sqrt(u_pred**2 + v_pred**2) | ||||||||
|
||||||||
# 保存结果 | ||||||||
np.savez( | ||||||||
osp.join(cfg.output_dir, "inference_results.npz"), | ||||||||
t=t_mesh.flatten(), | ||||||||
x=x_mesh.flatten(), | ||||||||
u=u_pred, | ||||||||
v=v_pred, | ||||||||
uv=uv_pred, | ||||||||
) | ||||||||
|
||||||||
# 可视化 | ||||||||
plot_points = paddle.concat([t_flatten, x_flatten], axis=-1).numpy() | ||||||||
data_exact = ( | ||||||||
dataset_val["uv_sol"] if dataset_val is not None else np.zeros_like(uv_pred) | ||||||||
) | ||||||||
figname = ( | ||||||||
"schrodinger_uv_infer_with_exact" | ||||||||
if dataset_val is not None | ||||||||
else "schrodinger_uv_infer" | ||||||||
) | ||||||||
|
||||||||
plot_func.draw_and_save( | ||||||||
figname=figname, | ||||||||
data_exact=data_exact, | ||||||||
data_learned=uv_pred, | ||||||||
boundary=[cfg.T_LB, cfg.T_UB, cfg.X_LB, cfg.X_UB], | ||||||||
griddata_points=plot_points, | ||||||||
griddata_xi=(t_mesh, x_mesh), | ||||||||
save_path=cfg.output_dir, | ||||||||
) | ||||||||
|
||||||||
logger.info(f"Inference completed. Results saved to {cfg.output_dir}") | ||||||||
|
||||||||
|
||||||||
@hydra.main(version_base=None, config_path="./conf", config_name="schrodinger.yaml") | ||||||||
def main(cfg: DictConfig): | ||||||||
if cfg.mode == "train": | ||||||||
train(cfg) | ||||||||
elif cfg.mode == "eval": | ||||||||
evaluate(cfg) | ||||||||
elif cfg.mode == "export": | ||||||||
export(cfg) | ||||||||
elif cfg.mode == "infer": | ||||||||
inference(cfg) | ||||||||
else: | ||||||||
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") | ||||||||
raise ValueError( | ||||||||
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'" | ||||||||
) | ||||||||
|
||||||||
|
||||||||
if __name__ == "__main__": | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.