Skip to content

Commit f7a5a34

Browse files
committed
feat(ppsci): support data_effient_nopt
1 parent dcdbf60 commit f7a5a34

File tree

4 files changed

+34
-842
lines changed

4 files changed

+34
-842
lines changed

docs/zh/examples/data_efficient_nopt.md

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77
# Download possion_64 data from https://drive.google.com/drive/folders/1crIsTZGxZULWhrXkwGDiWF33W6RHxJkf
88
# Download helmholtz_64 data from https://drive.google.com/drive/folders/1UjIaF6FsjmN_xlGGSUX-1K2V3EF2Zalw
99

10-
# Update the file paths in `config/operators_possion.yaml` or `config/operators_helmholtz.yaml` to specify `train_path`, `val_path`, `test_path`, `scales_path`, and `train_rand_idx_path`.
11-
12-
# possion_64 pretrain
13-
python pretrain_basic.py --run_name r0 --config pois-64-pretrain-e1_20_m3 --yaml_config ./config/operators_poisson.yaml
14-
15-
# possion_64 finetune
16-
python pretrain_basic.py --run_name r0 --config pois-64-e5_15_b0 --yaml_config ./config/operators_poisson.yaml
17-
18-
# helmholtz_64 pretrain
19-
python pretrain_basic.py --run_name r0 --config helm-64-pretrain-o1_20_m1 --yaml_config ./config/operators_helmholtz.yaml
20-
21-
# helmholtz_64 finetune
22-
python pretrain_basic.py --run_name r0 --config helm-64-o5_15_ft5_r2 --yaml_config ./config/operators_helmholtz.yaml
10+
# Update the file paths in `cexamples/data_efficient_nopt/config/data_efficient_nopt.yaml`, specify to mode in `train`
11+
# UPdate the file paths in config/operators_poisson.yaml or config/operators_helmholtz.yaml, specify to `train_path`, `val_path`, `test_path`, `scales_path` and `train_rand_idx_path`
12+
13+
# possion_64 pretrain, specify as following:
14+
# run_name: r0
15+
# config: pois-64-pretrain-e1_20_m3
16+
# yaml_config: config/operators_poisson.yaml
17+
python data_efficient_nopt.py
18+
19+
# helmholtz_64 pretrain, specify as following:
20+
# run_name: r0
21+
# config: helm-64-pretrain-o1_20_m1
22+
# yaml_config: config/operators_helmholtz.yaml
23+
python data_efficient_nopt.py
2324
```
2425

2526
=== "模型评估命令"
@@ -34,9 +35,14 @@
3435

3536
``` sh
3637
cd examples/data_efficient_nopt
37-
# Update the file paths in `config/operators_possion.yaml` or `config/operators_helmholtz.yaml` to specify `train_path`, `test_path`, and `scales_path`.
38+
# Update the file paths in `cexamples/data_efficient_nopt/config/data_efficient_nopt.yaml`, specify to mode in `infer`
3839
# Use a fine-tuned model as the checkpoint in 'exp' or utilize `model_convert.py` to convert the official checkpoint.
39-
python3 inference_fno_helmholtz_poisson.py --config ./config/inference_poisson.yaml --ckpt_path <ckpt_path> --num_demos 1
40+
# UPdate the file paths in config/inference_poisson.yaml or config/inference_poisson.yaml, specify to `train_path`, `test_path` and `scales_path`
41+
42+
# possion_64 inference, specify as following:
43+
# evaluation: config/inference_poisson.yaml
44+
# ckpt_path: <ckpt_path>
45+
python data_efficient_nopt.py
4046
```
4147

4248
## 1. 背景简介

examples/data_efficient_nopt/config/data_efficient_nopt.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# general settings
2-
mode: train # running mode: train/eval
2+
mode: train # running mode: train/infer
33
seed: 42
44

55
# training settings
@@ -10,7 +10,7 @@ yaml_config: config/operators_poisson.yaml
1010
sweep_id: ''
1111

1212
# evaluation settings
13-
eval_config: config/inference_poisson.yaml
13+
infer_config: config/inference_poisson.yaml
1414
ckpt_path: data/pd_finetune_b01_m0_n8192.tar
1515
num_demos: 1
1616
tqdm: False

examples/data_efficient_nopt/data_efficient_nopt.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import yaml
3131
from einops import rearrange
3232
from omegaconf import DictConfig
33-
from pretrain_basic import l2_err
3433
from ruamel.yaml import YAML
3534
from ruamel.yaml.comments import CommentedMap as ruamelDict
3635
from scipy.stats import linregress
@@ -47,6 +46,14 @@
4746
from ppsci.data.dataset.data_efficient_nopt_dataset import PoisHelmDatasetLoader
4847

4948

49+
def l2_err(pred, target, spatial_dim=(-1, -2, -3)):
50+
x = paddle.sum((pred - target) ** 2, axis=spatial_dim) / paddle.sum(
51+
target**2, axis=spatial_dim
52+
)
53+
x = paddle.sqrt(x)
54+
return paddle.mean(x) # , dim=0)
55+
56+
5057
def grad_norm(parameters):
5158
with paddle.no_grad():
5259
total_norm = 0
@@ -961,7 +968,7 @@ def get_pred(cfg):
961968
)
962969

963970

964-
def evaluate(cfg: DictConfig):
971+
def inference(cfg: DictConfig):
965972
get_pred(cfg)
966973

967974

@@ -971,8 +978,8 @@ def evaluate(cfg: DictConfig):
971978
def main(cfg: DictConfig):
972979
if cfg.mode == "train":
973980
train(cfg)
974-
elif cfg.mode == "eval":
975-
evaluate(cfg)
981+
elif cfg.mode == "infer":
982+
inference(cfg)
976983
else:
977984
raise ValueError(
978985
f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"

0 commit comments

Comments
 (0)