Skip to content

Commit 6c720a4

Browse files
[Refine&Doc] Fix ijcai url and refien eval code (#970)
* update ijcai2024 url * support custom collate_fn via setting 'collate_fn' in dataloader_cfg * update batch_size computation in eval * update code * fix for CI * add content in recent updating * add spinn aistudio url
1 parent d81b49f commit 6c720a4

File tree

6 files changed

+19
-9
lines changed

6 files changed

+19
-9
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
🔥 [IJCAI 2024: 任意三维几何外形车辆的风阻快速预测竞赛](https://competition.atomgit.com/competitionInfo?id=7f3f276465e9e845fd3a811d2d6925b5),track A, B, C 代码:
1717

18-
- [paddle实现](../jointContribution/IJCAI_2024/README.md)
18+
- [paddle实现](./jointContribution/IJCAI_2024/README.md)
1919
- [pytorch实现](https://competition.atomgit.com/competitionInfo?id=7f3f276465e9e845fd3a811d2d6925b5)(点击**排行榜**可查看各个赛道前10名的代码)
2020

2121
<!-- --8<-- [start:description] -->
@@ -108,6 +108,8 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
108108
<!-- --8<-- [start:update] -->
109109
## 🕘最近更新
110110

111+
- 添加 [IJCAI 2024: 任意三维几何外形车辆的风阻快速预测竞赛](https://competition.atomgit.com/competitionInfo?id=7f3f276465e9e845fd3a811d2d6925b5),track A, B, C 的 paddle/pytorch 代码链接。
112+
- 添加 SPINN(基于 Helmholtz3D 方程求解) [helmholtz3d](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/spinn/)
111113
- 添加 CVit(基于 Advection 方程和 N-S 方程求解) [CVit(Navier-Stokes)](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/ns_cvit/)[CVit(Advection)](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/adv_cvit/)
112114
- 添加 PirateNet(基于 Allen-cahn 方程和 N-S 方程求解) [Allen-Cahn](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/allen_cahn/)[LDC2D(Re3200)](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/ldc2d_steady/)
113115
- 基于 PaddleScience 的快速热仿真方法 [A fast general thermal simulation model based on MultiBranch Physics-Informed deep operator neural network](https://pubs.aip.org/aip/pof/article-abstract/36/3/037142/3277890/A-fast-general-thermal-simulation-model-based-on?redirectedFrom=fulltext) 被 Physics of Fluids 2024 接受。

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
🔥 [IJCAI 2024: 任意三维几何外形车辆的风阻快速预测竞赛](https://competition.atomgit.com/competitionInfo?id=7f3f276465e9e845fd3a811d2d6925b5),track A, B, C 代码:
88

9-
- [paddle实现](../jointContribution/IJCAI_2024/README.md)
9+
- [paddle实现](https://github.com/PaddlePaddle/PaddleScience/tree/develop/jointContribution/IJCAI_2024)
1010
- [pytorch实现](https://competition.atomgit.com/competitionInfo?id=7f3f276465e9e845fd3a811d2d6925b5)(点击**排行榜**可查看各个赛道前10名的代码)
1111

1212
<style>

docs/zh/examples/spinn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPINN(helmholtz3d)
22

3-
<!-- <a href="https://aistudio.baidu.com/projectdetail/8219967" class="md-button md-button--primary" style>AI Studio快速体验</a> -->
3+
<a href="https://aistudio.baidu.com/projectdetail/8219967" class="md-button md-button--primary" style>AI Studio快速体验</a>
44

55
=== "模型训练命令"
66

ppsci/data/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import copy
1616
import random
1717
from functools import partial
18+
from typing import Callable
19+
from typing import Optional
1820

1921
import numpy as np
2022
import paddle.distributed as dist
@@ -101,9 +103,11 @@ def build_dataloader(_dataset, cfg):
101103

102104
# build collate_fn if specified
103105
batch_transforms_cfg = cfg.pop("batch_transforms", None)
104-
collate_fn = None
106+
collate_fn: Optional[Callable] = cfg.pop("collate_fn", None)
105107
if isinstance(batch_transforms_cfg, (list, tuple)):
106-
collate_fn = batch_transform.build_batch_transforms(batch_transforms_cfg)
108+
collate_fn = batch_transform.build_batch_transforms(
109+
batch_transforms_cfg, collate_fn
110+
)
107111

108112
# build init function
109113
_DEFAULT_NUM_WORKERS = 1

ppsci/data/process/batch_transform/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any
2020
from typing import Callable
2121
from typing import List
22+
from typing import Optional
2223

2324
import numpy as np
2425
import paddle
@@ -118,15 +119,17 @@ def build_transforms(cfg):
118119
return transform.Compose(transform_list)
119120

120121

121-
def build_batch_transforms(cfg):
122+
def build_batch_transforms(cfg, collate_fn: Optional[Callable]):
122123
cfg = copy.deepcopy(cfg)
123124
batch_transforms: Callable[[List[Any]], List[Any]] = build_transforms(cfg)
125+
if collate_fn is None:
126+
collate_fn = default_collate_fn
124127

125128
def collate_fn_batch_transforms(batch: List[Any]):
126129
# apply batch transform on separate samples
127130
batch = batch_transforms(batch)
128131

129132
# then collate separate samples into batched data
130-
return default_collate_fn(batch)
133+
return collate_fn(batch)
131134

132135
return collate_fn_batch_transforms

ppsci/solver/eval.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from paddle import io
2626

2727
from ppsci.solver import printer
28+
from ppsci.solver.train import _compute_batch_size
2829
from ppsci.utils import misc
2930

3031
if TYPE_CHECKING:
@@ -128,7 +129,7 @@ def _eval_by_dataset(
128129
batch_cost = time.perf_counter() - batch_tic
129130
solver.eval_time_info["reader_cost"].update(reader_cost)
130131
solver.eval_time_info["batch_cost"].update(batch_cost)
131-
batch_size = next(iter(input_dict.values())).shape[0]
132+
batch_size = _compute_batch_size(input_dict)
132133
printer.update_eval_loss(solver, loss_dict, batch_size)
133134
if (
134135
iter_id == 1
@@ -216,7 +217,7 @@ def _eval_by_batch(
216217
input_dict, label_dict, weight_dict = batch
217218
reader_cost = time.perf_counter() - reader_tic
218219

219-
batch_size = next(iter(input_dict.values())).shape[0]
220+
batch_size = _compute_batch_size(input_dict)
220221
for v in input_dict.values():
221222
if hasattr(v, "stop_gradient"):
222223
v.stop_gradient = False

0 commit comments

Comments
 (0)