Skip to content

Commit 3a1d360

Browse files
[Example] Add adv_cvit and ns_cvit (#939)
* update develop mkdocs * allow alias for mike * update CViT code(WIP) * update CViT code(WIP) * update validate code * update code * update code * refine code * refine docs * update docs * update export&inference code * update ns cvit code(WIP, not aligned) * update reprod code * rename block name according to their class * add more config yamls * fix data/__init__ * fix * refine code and add more annotations * update code * update TRT steps * change pos/time embedding from buffer to trainable parameters * use interpolation for spatial_dims * remove interpolation * update config * fix std of normal initializer * update outputfile * add einops into req * restore l2_rel * refine eval checking and logging * update example code of FunctionalBatchTransform * update pretrained url and plot code * add cvit doc * update adv_cvit doc * rename title for adv_cvit.md * update docs * update docs * fix zh/examples/extformer_moe.md * update docs * refine code
1 parent c3c239d commit 3a1d360

40 files changed

+3312
-170
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
4646

4747
| 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 |
4848
|-----|---------|-----|---------|----|---------|---------|
49-
| 定常不可压流体 | [2D 定常方腔流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ldc2d_steady) | 机理驱动 | MLP | 无监督学习 | - | |
49+
| 一维线性对流问题 | [1D 线性对流](https://paddlescience-docs.readthedocs.io/zh/examples/adv_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) |
50+
| 非定常不可压流体 | [2D 方腔浮力驱动流](https://paddlescience-docs.readthedocs.io/zh/examples/ns_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) |
51+
| 定常不可压流体 | [Re3200 2D 定常方腔流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ldc2d_steady) | 机理驱动 | MLP | 无监督学习 | - | |
5052
| 定常不可压流体 | [2D 达西流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/darcy2d) | 机理驱动 | MLP | 无监督学习 | - | |
5153
| 定常不可压流体 | [2D 管道流](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/labelfree_DNN_surrogate) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/1906.02382) |
5254
| 定常不可压流体 | [3D 血管瘤](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/aneurysm) | 机理驱动 | MLP | 无监督学习 | [Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/aneurysm/aneurysm_dataset.tar) | [Project](https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html)|
@@ -102,6 +104,7 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
102104
<!-- --8<-- [start:update] -->
103105
## 🕘最近更新
104106

107+
- 添加 CVit(基于 Advection 方程和 N-S 方程求解) [CVit(Navier-Stokes)](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ns_cvit/)[CVit(Advection)](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/adv_cvit/)
105108
- 添加 PirateNet(基于 Allen-cahn 方程和 N-S 方程求解) [Allen-Cahn](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/allen_cahn/)[LDC2D(Re3200)](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ldc2d_steady/)
106109
- 基于 PaddleScience 的快速热仿真方法 [A fast general thermal simulation model based on MultiBranch Physics-Informed deep operator neural network](https://pubs.aip.org/aip/pof/article-abstract/36/3/037142/3277890/A-fast-general-thermal-simulation-model-based-on?redirectedFrom=fulltext) 被 Physics of Fluids 2024 接受。
107110
- 添加多目标优化算法 [Relobralo](https://paddlescience-docs.readthedocs.io/zh/latest/zh/api/loss/mtl/#ppsci.loss.mtl.Relobralo)

deploy/python_infer/pinn_predictor.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Dict
1616
from typing import List
17+
from typing import Optional
1718
from typing import Union
1819

1920
import numpy as np
@@ -106,21 +107,22 @@ def __init__(
106107
def predict(
107108
self,
108109
input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
109-
batch_size: int = 64,
110+
batch_size: Optional[int] = 64,
110111
) -> Dict[str, np.ndarray]:
111112
"""
112113
Predicts the output of the model for the given input.
113114
114115
Args:
115116
input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]):
116117
A dictionary containing the input data.
117-
batch_size (int, optional): The batch size to use for prediction.
118-
Defaults to 64.
118+
batch_size (Optional[int]): The batch size to use for prediction.
119+
If None, input will be directly sent to the model
120+
without batch slicing. Defaults to 64.
119121
120122
Returns:
121123
Dict[str, np.ndarray]: A dictionary containing the predicted output.
122124
"""
123-
if batch_size > self.max_batch_size:
125+
if batch_size and batch_size > self.max_batch_size:
124126
logger.warning(
125127
f"batch_size({batch_size}) is larger than "
126128
f"max_batch_size({self.max_batch_size}), which may occur error."
@@ -143,7 +145,7 @@ def predict(
143145
]
144146

145147
num_samples = len(next(iter(input_dict.values())))
146-
batch_num = (num_samples + (batch_size - 1)) // batch_size
148+
batch_num = (num_samples + (batch_size - 1)) // batch_size if batch_size else 1
147149
pred_dict = misc.Prettydefaultdict(list)
148150

149151
# inference by batch
@@ -152,9 +154,12 @@ def predict(
152154
logger.info(f"Predicting batch {batch_id}/{batch_num}")
153155

154156
# prepare batch input dict
155-
st = (batch_id - 1) * batch_size
156-
ed = min(num_samples, batch_id * batch_size)
157-
batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}
157+
if batch_size:
158+
st = (batch_id - 1) * batch_size
159+
ed = min(num_samples, batch_id * batch_size)
160+
batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}
161+
else:
162+
batch_input_dict = {key: input_dict[key] for key in input_dict}
158163

159164
# send batch input data to input handle(s)
160165
if self.engine != "onnx":

docs/index.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@
9191

9292
| 问题类型 | 案例名称 | 优化算法 | 模型类型 | 训练方式 | 数据集 | 参考资料 |
9393
|-----|---------|-----|---------|----|---------|---------|
94-
| 定常不可压流体 | [2D 定常方腔流](./zh/examples/ldc2d_steady.md) | 机理驱动 | MLP | 无监督学习 | - | |
94+
| 一维线性对流问题 | [1D 线性对流](./zh/examples/adv_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) |
95+
| 非定常不可压流体 | [2D 方腔浮力驱动流](./zh/examples/ns_cvit.md) | 数据驱动 | ViT | 监督学习 | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) |
96+
| 定常不可压流体 | [Re3200 2D 定常方腔流](./zh/examples/ldc2d_steady.md) | 机理驱动 | MLP | 无监督学习 | - | |
9597
| 定常不可压流体 | [2D 达西流](./zh/examples/darcy2d.md) | 机理驱动 | MLP | 无监督学习 | - | |
9698
| 定常不可压流体 | [2D 管道流](./zh/examples/labelfree_DNN_surrogate.md) | 机理驱动 | MLP | 无监督学习 | - | [Paper](https://arxiv.org/abs/1906.02382) |
9799
| 定常不可压流体 | [3D 血管瘤](./zh/examples/aneurysm.md) | 机理驱动 | MLP | 无监督学习 | [Data](https://paddle-org.bj.bcebos.com/paddlescience/datasets/aneurysm/aneurysm_dataset.tar) | [Project](https://docs.nvidia.com/deeplearning/modulus/modulus-v2209/user_guide/intermediate/adding_stl_files.html)|

docs/zh/api/arch.md

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,32 @@
44
handler: python
55
options:
66
members:
7-
- Arch
7+
- AFNONet
88
- AMGNet
9-
- MLP
10-
- ModifiedMLP
11-
- PirateNet
9+
- Arch
10+
- AutoEncoder
11+
- ChipDeepONets
12+
- CuboidTransformer
13+
- CVit1D
14+
- CylinderEmbedding
1215
- DeepONet
1316
- DeepPhyLSTM
14-
- LorenzEmbedding
15-
- RosslerEmbedding
16-
- CylinderEmbedding
17-
- Generator
17+
- DGMR
1818
- Discriminator
19-
- PhysformerGPT2
19+
- ExtFormerMoECuboid
20+
- Generator
21+
- HEDeepONets
22+
- LorenzEmbedding
23+
- MLP
2024
- ModelList
21-
- AFNONet
22-
- PrecipNet
23-
- PhyCRNet
24-
- UNetEx
25-
- USCNN
25+
- ModifiedMLP
2626
- NowcastNet
27-
- HEDeepONets
28-
- DGMR
29-
- ChipDeepONets
30-
- AutoEncoder
31-
- CuboidTransformer
32-
- ExtFormerMoECuboid
3327
- SFNONet
34-
- UNONet
3528
- TFNO1dNet
3629
- TFNO2dNet
3730
- TFNO3dNet
31+
- UNetEx
32+
- UNONet
33+
- USCNN
3834
show_root_heading: true
3935
heading_level: 3

docs/zh/api/equation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
- Laplace
1313
- LinearElasticity
1414
- NavierStokes
15+
- NLSMB
1516
- NormalDotVec
1617
- Poisson
1718
- Vibration
1819
- Volterra
19-
- NLSMB
2020
show_root_heading: true
2121
heading_level: 3

docs/zh/api/experimental.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
- bessel_i1e
1515
- fractional_diff
1616
- gaussian_integrate
17-
- trapezoid_integrate
1817
- montecarlo_integrate
18+
- trapezoid_integrate
1919
show_root_heading: true
2020
heading_level: 3

docs/zh/api/geometry.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
options:
66
members:
77
- Geometry
8-
- Interval
9-
- Disk
10-
- Polygon
11-
- Rectangle
12-
- Triangle
138
- Cuboid
14-
- Sphere
9+
- Disk
1510
- Hypercube
1611
- Hypersphere
12+
- Interval
1713
- Mesh
1814
- PointCloud
15+
- Polygon
16+
- Rectangle
17+
- Sphere
1918
- TimeDomain
2019
- TimeXGeometry
20+
- Triangle
2121
show_root_heading: true
2222
heading_level: 3

docs/zh/api/lr_scheduler.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
handler: python
55
options:
66
members:
7-
- Linear
87
- Cosine
9-
- Step
10-
- Piecewise
11-
- MultiStepDecay
12-
- ExponentialDecay
138
- CosineWarmRestarts
9+
- ExponentialDecay
10+
- Linear
11+
- MultiStepDecay
1412
- OneCycleLR
13+
- Piecewise
14+
- Step
1515
show_root_heading: true
1616
heading_level: 3

docs/zh/api/metric.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
members:
77
- Metric
88
- FunctionalMetric
9-
- MAE
10-
- MSE
11-
- RMSE
129
- L2Rel
13-
- MeanL2Rel
1410
- LatitudeWeightedACC
1511
- LatitudeWeightedRMSE
12+
- MAE
13+
- MeanL2Rel
14+
- MSE
15+
- RMSE
1616
show_root_heading: true
1717
heading_level: 3

docs/zh/api/optimizer.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
handler: python
55
options:
66
members:
7-
- SGD
8-
- Momentum
97
- Adam
108
- AdamW
11-
- RMSProp
129
- LBFGS
10+
- Momentum
1311
- OptimizerList
12+
- RMSProp
13+
- SGD
1414
show_root_heading: true
1515
heading_level: 3

0 commit comments

Comments
 (0)