Skip to content

Commit b1d8a3d

Browse files
committed
feat(ppsci): support data_effient_nopt for inference
1 parent 996df12 commit b1d8a3d

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

examples/data_efficient_nopt/inference_fno_helmholtz_poisson.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def get_pred(args):
6666
model = build_fno(params)
6767

6868
if args.ckpt_path:
69-
raise NotImplementedError("Loading checkpoint is not supported")
7069
checkpoint = paddle.load(args.ckpt_path)
7170
try:
7271
model.set_state_dict(checkpoint["model_state"])
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import paddle
2+
import torch
3+
4+
5+
def save_checkpoint(checkpoint_path, model):
6+
"""Save model and optimizer to checkpoint"""
7+
8+
paddle.save(
9+
{"model_state": model},
10+
checkpoint_path,
11+
)
12+
13+
14+
def torch2paddle():
15+
torch_path = "/home/aistudio/data_efficient_nopt/data/possion_64_inference/finetune_b01_m0_n8192.tar"
16+
paddle_path = "./data/pd_finetune_b01_m0_n8192.tar"
17+
18+
torch_state_dict = torch.load(torch_path)["model_state"]
19+
# model.set_state_dict(checkpoint["model_state"])
20+
fc_names = ["classifier", "fc"]
21+
paddle_state_dict = {}
22+
import pdb
23+
24+
pdb.set_trace()
25+
for k in torch_state_dict:
26+
if "num_batches_tracked" in k: # 飞桨中无此参数,无需保存
27+
continue
28+
v = torch_state_dict[k].detach().cpu().numpy()
29+
flag = [i in k for i in fc_names]
30+
if any(flag) and "weight" in k:
31+
new_shape = [1, 0] + list(range(2, v.ndim))
32+
print(
33+
f"name: {k}, ori shape: {v.shape}, new shape: {v.transpose(new_shape).shape}"
34+
)
35+
v = v.transpose(new_shape) # 转置 Linear 层的 weight 参数
36+
# 将 torch.nn.BatchNorm2d 的参数名称改成 paddle.nn.BatchNorm2D 对应的参数名称
37+
k = k.replace("running_var", "_variance")
38+
k = k.replace("running_mean", "_mean")
39+
k = k.replace("module.", "")
40+
# 添加到飞桨权重字典中
41+
# k = k
42+
print(f"k: {k}")
43+
paddle_state_dict[k] = v
44+
print(f"paddle_state_dict: {paddle_state_dict.keys()}")
45+
save_checkpoint(paddle_path, paddle_state_dict)
46+
47+
48+
if __name__ == "__main__":
49+
torch2paddle()

examples/data_efficient_nopt/run_inference_possion.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ set -ex
44

55
CUDA_VISIBLE_DEVICES=0 python3.9 inference_fno_helmholtz_poisson.py \
66
--config config/inference_poisson.yaml \
7-
--ckpt_path /home/aistudio/data_efficient_nopt/data/possion_64_inference/finetune_b01_m0_n8192.tar \
7+
--ckpt_path /home/aistudio/xiaoyewww-data/PaddleScience/examples/data_efficient_nopt/data/pd_finetune_b01_m0_n8192.tar \
88
--num_demos 1

0 commit comments

Comments
 (0)