
Commit 2706d22

[Hackathon 7th PPSCI No.4] NO.4 DrivAerNet paper reproduction (#1047)
* Create ReadME.md
* Add files via upload
* Delete examples/fsi/DrivAerNet directory
* Create ReadME.md
* Add files via upload
* drivaernet
* Delete examples/fsi/DrivAerNet-paddle-convert-main directory
* modify some style
* modify yaml
* update example drivaernet code
* update requirments and delete fsi/DrivAerNet.py etc.
* Update drivaernet_dataset.py
* update drivaernet_dataset.py
* update drivaernet_dataset.py
* Update __init__.py
* Update __init__.py
* Update DrivAerNet.py
* Update DrivAerNet.py
* Update __init__.py
* update
* support fraction of the training data
* modify some error about version
* Update regdgcnn.py
* Delete examples/DrivAerNet/requirments.txt
* Delete docs/zh/examples/drivaernet directory
* Update DriveAerNet.yaml
* Rename DrivAerNet.md to drivaernet.md
* Update mkdocs.yml
* Update drivaernet.md
* Update DriveAerNet.yaml
* Update DrivAerNet.py
* Update DriveAerNet.yaml
* Update DrivAerNet.py
* Update DrivAerNet.py
* Update drivaernet.md
* Update DrivAerNet.py
* Update DriveAerNet.yaml
* Update drivaernet_dataset.py
* Update solver.py
* update metric.md
* Update metric.md
* Update optimizer.md
* Update arch.md
* Update DrivAerNetDataset dataset.md
* Update drivaernet_dataset.py
* Update drivaernet.md
* Update dataset.md
* Update arch.md
* Update lr_scheduler.py
* Delete docs/zh/api/arch.md
* Delete docs/zh/api/data/dataset.md
* Delete ppsci/optimizer/lr_scheduler.py
* Update optimizer.md
* Delete docs/zh/api/optimizer.md
* Create arch.md
* Create dataset.md
* Create optimizer.md
* Update lr_scheduler.md
* Create lr_scheduler.py
* Update solver.py
* Update drivaernet.md
* Update arch.md
* Update dataset.md
* Update lr_scheduler.py
* Update lr_scheduler.md
* Rename DriveAerNet.yaml to driveaernet.yaml
* Update and rename DrivAerNet.py to drivaernet.py
* Rename driveaernet.yaml to drivaernet.yaml
* Update drivaernet.md
* Update r2_score.py
* Update max_ae.py
* Update drivaernet_dataset.py
* Update drivaernet.md
* Update solver.py
* Update solver.py
* Update drivaernet.py
* Create drivaernet
* Delete examples/drivaernet
* Create drivaernet
* Delete examples/drivaernet
* Delete examples/DrivAerNet directory
* Create drivaernet.py
* Create drivaernet.yaml
* Update drivaernet.yaml
* Update drivaernet.md
* Update drivaernet.md
* Update regdgcnn.py
* Update lr_scheduler.py
* Update lr_scheduler.py
* Update solver.py
* Update lr_scheduler.py
* Update examples/drivaernet/drivaernet.py
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update examples/drivaernet/conf/drivaernet.yaml
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update examples/drivaernet/conf/drivaernet.yaml
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update examples/drivaernet/conf/drivaernet.yaml
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update examples/drivaernet/drivaernet.py
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update ppsci/arch/regdgcnn.py
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update ppsci/arch/regdgcnn.py
  Co-authored-by: HydrogenSulfate <[email protected]>
* Update regdgcnn.py
* Update lr_scheduler.py
* Update drivaernet_dataset.py
* Update drivaernet_dataset.py
* Update solver.py
* Update ppsci/solver/solver.py

---------

Co-authored-by: HydrogenSulfate <[email protected]>
1 parent 6652bb8 commit 2706d22

File tree

17 files changed, +1627 −6 lines changed

docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions

@@ -37,5 +37,6 @@
 - USCNN
 - LNO
 - TGCN
+- RegDGCNN
 show_root_heading: true
 heading_level: 3
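
The new RegDGCNN entry documents the point-cloud regression network registered in ppsci.arch by this commit. Below is a minimal construction sketch mirroring the call made in examples/drivaernet/drivaernet.py further down; the key names and hyperparameter values are copied from the example's MODEL config, and passing a plain OmegaConf mapping as args is an assumption based on how the example passes its Hydra config.

import ppsci
from omegaconf import OmegaConf

# Hyperparameters copied from examples/drivaernet/conf/drivaernet.yaml (MODEL section).
model_cfg = OmegaConf.create(
    {
        "input_keys": ["vertices"],
        "output_keys": ["cd_value"],
        "weight_keys": ["weight_keys"],
        "dropout": 0.4,
        "emb_dims": 512,
        "k": 40,
        "output_channels": 1,
    }
)

# RegDGCNN takes the key names plus the full config as `args`,
# following the train()/evaluate() functions in this commit.
model = ppsci.arch.RegDGCNN(
    input_keys=model_cfg.input_keys,
    label_keys=model_cfg.output_keys,
    weight_keys=model_cfg.weight_keys,
    args=model_cfg,
)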

docs/zh/api/data/dataset.md

Lines changed: 1 addition & 0 deletions

@@ -32,4 +32,5 @@
 - MOlFLOWDataset
 - CGCNNDataset
 - PEMSDataset
+- DrivAerNetDataset
 show_root_heading: true
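
The DrivAerNetDataset added here is consumed through PaddleScience's dataloader config dict rather than instantiated directly. A hedged sketch of the training-side configuration, condensed from train() in examples/drivaernet/drivaernet.py below; the paths are the placeholder values from the example's ARGS section and are illustrative only.

import ppsci

# Dataloader configuration for the new DrivAerNetDataset, condensed from
# examples/drivaernet/drivaernet.py. Paths and sizes are illustrative.
train_dataloader_cfg = {
    "dataset": {
        "name": "DrivAerNetDataset",
        "root_dir": "data/DrivAerNet_Processed_Point_Clouds_5k_paddle",
        "input_keys": ("vertices",),
        "label_keys": ("cd_value",),
        "weight_keys": ("weight_keys",),
        "subset_dir": "data/subset_dir",
        "ids_file": "train_design_ids.txt",
        "csv_file": "data/AeroCoefficients_DrivAerNet_FilteredCorrected.csv",
        "num_points": 5000,
        "train_fractions": 1,
        "mode": "train",
    },
    "batch_size": 1,
    "num_workers": 32,
}

# The dict is handed to a SupervisedConstraint, which builds the dataset by name.
constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.MSELoss("mean"),
    name="DrivAerNet_constraint",
)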

docs/zh/api/lr_scheduler.md

Lines changed: 1 addition & 0 deletions

@@ -13,5 +13,6 @@
 - OneCycleLR
 - Piecewise
 - Step
+- ReduceOnPlateau
 show_root_heading: true
 heading_level: 3
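
The new ReduceOnPlateau scheduler is used via PaddleScience's wrap-then-call pattern, as in drivaernet.py below. A sketch with the example's own settings (learning rate 0.001, mode "min", patience 20, factor 0.1); the iters_per_epoch value here is fixed to an illustrative number, whereas the example derives it from its config.

import ppsci

# ReduceOnPlateau wrapper from ppsci.optimizer.lr_scheduler; calling the wrapper
# instance returns the underlying Paddle scheduler, as done in drivaernet.py below.
lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau(
    epochs=100,
    iters_per_epoch=2777,  # illustrative; the example computes this from cfg
    learning_rate=0.001,
    mode="min",      # monitor a metric that should decrease (validation MSE)
    patience=20,     # epochs without improvement before decaying
    factor=0.1,      # multiply the learning rate by 0.1 on a plateau
    verbose=True,
)()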

docs/zh/api/metric.md

Lines changed: 2 additions & 0 deletions

@@ -13,5 +13,7 @@
 - MeanL2Rel
 - MSE
 - RMSE
+- MaxAE
+- R2Score
 show_root_heading: true
 heading_level: 3
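
MaxAE and R2Score are the two evaluation metrics this commit adds for drag-coefficient regression (used in evaluate() below). As a standalone illustration of what they measure, assuming the conventional definitions of maximum absolute error and the coefficient of determination; the real implementations live in ppsci.metric.MaxAE and ppsci.metric.R2Score.

import paddle

# Reference computation of the two new metrics under the standard formulas.
label = paddle.to_tensor([0.30, 0.28, 0.35, 0.31])
pred = paddle.to_tensor([0.29, 0.30, 0.33, 0.31])

max_ae = paddle.max(paddle.abs(pred - label))  # worst-case absolute error
ss_res = paddle.sum((label - pred) ** 2)
ss_tot = paddle.sum((label - label.mean()) ** 2)
r2 = 1.0 - ss_res / ss_tot  # 1.0 means a perfect fit

print(float(max_ae), float(r2))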

docs/zh/examples/drivaernet.md

Lines changed: 498 additions & 0 deletions
Large diffs are not rendered by default.
examples/drivaernet/conf/drivaernet.yaml

Lines changed: 76 additions & 0 deletions
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    dir: outputs_drivaernet/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode}
    chdir: false
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train
seed: 1
output_dir: ${hydra:run.dir}
log_freq: 100

# model settings
MODEL:
  input_keys: ["vertices"]
  output_keys: ["cd_value"]
  weight_keys: ["weight_keys"]
  dropout: 0.4
  emb_dims: 512
  k: 40
  output_channels: 1

# training settings
TRAIN:
  iters_per_epoch: 2776
  num_points: 5000
  epochs: 100
  num_workers: 32
  eval_during_train: True
  train_ids_file: "train_design_ids.txt"
  eval_ids_file: "val_design_ids.txt"
  batch_size: 1
  train_fractions: 1
  scheduler:
    mode: "min"
    patience: 20
    factor: 0.1
    verbose: True

# evaluation settings
EVAL:
  num_points: 5000
  batch_size: 2
  pretrained_model_path: "https://dataset.bj.bcebos.com/PaddleScience/DNNFluid-Car/DrivAer/CdPrediction_DrivAerNet_r2_100epochs_5k_best_model.pdparams"
  eval_with_no_grad: True
  ids_file: "test_design_ids.txt"
  num_workers: 8

# optimizer settings
optimizer:
  weight_decay: 0.0001
  lr: 0.001
  optimizer: 'adam'

ARGS:
  # dataset settings
  dataset_path: 'data/DrivAerNet_Processed_Point_Clouds_5k_paddle'
  aero_coeff: 'data/AeroCoefficients_DrivAerNet_FilteredCorrected.csv'
  subset_dir: 'data/subset_dir'
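
Note how TRAIN.iters_per_epoch, TRAIN.train_fractions and TRAIN.batch_size interact: the example script below rescales the nominal iteration count by the training-data fraction and divides by world size times batch size. A small worked sketch of that expression with the defaults above, assuming a single GPU:

# Effective iterations per epoch, mirroring the expression used in
# examples/drivaernet/drivaernet.py. Defaults above, single device assumed.
iters_per_epoch = 2776
train_fractions = 1
world_size = 1  # paddle.distributed.get_world_size() on one device
batch_size = 1

effective_iters = iters_per_epoch * train_fractions // (world_size * batch_size) + 1
print(effective_iters)  # 2777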

examples/drivaernet/drivaernet.py

Lines changed: 202 additions & 0 deletions
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial

import hydra
import paddle
from omegaconf import DictConfig

import ppsci


def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.RegDGCNN(
        input_keys=cfg.MODEL.input_keys,
        label_keys=cfg.MODEL.output_keys,
        weight_keys=cfg.MODEL.weight_keys,
        args=cfg.MODEL,
    )

    train_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.TRAIN.train_ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.TRAIN.num_points,
            "train_fractions": cfg.TRAIN.train_fractions,
            "mode": cfg.mode,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    drivaernet_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        name="DrivAerNet_constraint",
    )

    constraint = {drivaernet_constraint.name: drivaernet_constraint}

    valid_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.TRAIN.eval_ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.TRAIN.num_points,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    drivaernet_valid = ppsci.validate.SupervisedValidator(
        valid_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={"MSE": ppsci.metric.MSE()},
        name="DrivAerNet_valid",
    )

    validator = {drivaernet_valid.name: drivaernet_valid}

    # set optimizer
    lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=(
            cfg.TRAIN.iters_per_epoch
            * cfg.TRAIN.train_fractions
            // (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
            + 1
        ),
        learning_rate=cfg.optimizer.lr,
        mode=cfg.TRAIN.scheduler.mode,
        patience=cfg.TRAIN.scheduler.patience,
        factor=cfg.TRAIN.scheduler.factor,
        verbose=cfg.TRAIN.scheduler.verbose,
    )()

    optimizer = (
        ppsci.optimizer.Adam(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
            model
        )
        if cfg.optimizer.optimizer == "adam"
        else ppsci.optimizer.SGD(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
            model
        )
    )

    # initialize solver
    solver = ppsci.solver.Solver(
        model=model,
        iters_per_epoch=(
            cfg.TRAIN.iters_per_epoch
            * cfg.TRAIN.train_fractions
            // (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
            + 1
        ),
        constraint=constraint,
        output_dir=cfg.output_dir,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epochs=cfg.TRAIN.epochs,
        validator=validator,
        eval_during_train=cfg.TRAIN.eval_during_train,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric)
    solver.lr_scheduler = lr_scheduler

    # train model
    solver.train()

    solver.eval()


def evaluate(cfg: DictConfig):
    # set model
    model = ppsci.arch.RegDGCNN(
        input_keys=cfg.MODEL.input_keys,
        label_keys=cfg.MODEL.output_keys,
        weight_keys=cfg.MODEL.weight_keys,
        args=cfg.MODEL,
    )

    valid_dataloader_cfg = {
        "dataset": {
            "name": "DrivAerNetDataset",
            "root_dir": cfg.ARGS.dataset_path,
            "input_keys": ("vertices",),
            "label_keys": ("cd_value",),
            "weight_keys": ("weight_keys",),
            "subset_dir": cfg.ARGS.subset_dir,
            "ids_file": cfg.EVAL.ids_file,
            "csv_file": cfg.ARGS.aero_coeff,
            "num_points": cfg.EVAL.num_points,
            "mode": cfg.mode,
        },
        "batch_size": cfg.EVAL.batch_size,
        "num_workers": cfg.EVAL.num_workers,
    }

    drivaernet_valid = ppsci.validate.SupervisedValidator(
        valid_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={
            "MSE": ppsci.metric.MSE(),
            "MAE": ppsci.metric.MAE(),
            "Max AE": ppsci.metric.MaxAE(),
            "R²": ppsci.metric.R2Score(),
        },
        name="DrivAerNet_valid",
    )

    validator = {drivaernet_valid.name: drivaernet_valid}

    solver = ppsci.solver.Solver(
        model=model,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # evaluate model
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="drivaernet.yaml")
def main(cfg: DictConfig):
    warnings.filterwarnings("ignore")
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
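
One detail worth calling out: train() rebinds lr_scheduler.step with functools.partial so that, when the Solver later calls step() with no arguments, the ReduceOnPlateau scheduler still receives a metric value to decide whether to decay. Below is a self-contained sketch of the same pattern using a toy stand-in class; the DummyPlateauScheduler name and its logic are illustrative only, not the ppsci API.

from functools import partial


class DummyPlateauScheduler:
    """Toy stand-in for a ReduceOnPlateau-style scheduler (illustrative only)."""

    def __init__(self, lr: float, factor: float = 0.1):
        self.lr = lr
        self.factor = factor
        self.best = float("inf")

    def step(self, metrics: float) -> None:
        # Decay the learning rate whenever the monitored metric stops improving.
        if metrics < self.best:
            self.best = metrics
        else:
            self.lr *= self.factor


scheduler = DummyPlateauScheduler(lr=1e-3)
current_metric = 0.05  # e.g. the latest validation MSE

# Pin the metrics argument so later bare step() calls still pass it through,
# mirroring `lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric)`.
scheduler.step = partial(scheduler.step, metrics=current_metric)
scheduler.step()  # no arguments needed now
print(scheduler.lr)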

mkdocs.yml

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ nav:
 - tempoGAN: zh/examples/tempoGAN.md
 - NSFNet4: zh/examples/nsfnet4.md
 - ViV: zh/examples/viv.md
+- DrivAerNet: zh/examples/drivaernet.md
 - 结构:
 - Biharmonic2D: zh/examples/biharmonic2d.md
 - Bracket: zh/examples/bracket.md

ppsci/arch/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,7 @@
 from ppsci.arch.velocitygan import VelocityGenerator # isort:skip
 from ppsci.arch.moflow_net import MoFlowNet, MoFlowProp # isort:skip
 from ppsci.utils import logger # isort:skip
+from ppsci.arch.regdgcnn import RegDGCNN # isort:skip

 __all__ = [
     "MoFlowNet",
@@ -107,6 +108,7 @@
     "USCNN",
     "VelocityDiscriminator",
     "VelocityGenerator",
+    "RegDGCNN",
 ]
