Commit 96d45e9

[Doc] Fix document and formulations for cvit (#944)
* fix document and formulations
* correct doc of Data.batch_transform
* update schema for None batch size in INFER config
* support batch_size to be set to None where certain data is unnecessary to be batchified
* update pre-commit
* refine EVAL config schema
1 parent 0ec4019 commit 96d45e9

8 files changed, +76 -44 lines changed

README.md

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ PaddleScience is a PaddlePaddle-based scientific comp…
 
 | Problem type | Case name | Optimization algorithm | Model type | Training method | Dataset | References |
 |-----|---------|-----|---------|----|---------|---------|
-| 1D linear advection problem | [1D linear advection](https://paddlescience-docs.readthedocs.io/zh/examples/adv_cvit.md) | Data-driven | ViT | Supervised learning | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) |
-| Unsteady incompressible flow | [2D buoyancy-driven cavity flow](https://paddlescience-docs.readthedocs.io/zh/examples/ns_cvit.md) | Data-driven | ViT | Supervised learning | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) |
+| 1D linear advection problem | [1D linear advection](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/adv_cvit/) | Data-driven | ViT | Supervised learning | [Data](https://github.com/Zhengyu-Huang/Operator-Learning/tree/main/data) | [Paper](https://arxiv.org/abs/2405.13998) |
+| Unsteady incompressible flow | [2D buoyancy-driven cavity flow](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/ns_cvit/) | Data-driven | ViT | Supervised learning | [Data](https://huggingface.co/datasets/pdearena/NavierStokes-2D) | [Paper](https://arxiv.org/abs/2405.13998) |
 | Steady incompressible flow | [Re3200 2D steady lid-driven cavity flow](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/ldc2d_steady) | Mechanism-driven | MLP | Unsupervised learning | - | |
 | Steady incompressible flow | [2D Darcy flow](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/darcy2d) | Mechanism-driven | MLP | Unsupervised learning | - | |
 | Steady incompressible flow | [2D pipe flow](https://paddlescience-docs.readthedocs.io/zh/latest/zh/examples/labelfree_DNN_surrogate) | Mechanism-driven | MLP | Unsupervised learning | - | [Paper](https://arxiv.org/abs/1906.02382) |

deploy/python_infer/pinn_predictor.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def predict(
                 ed = min(num_samples, batch_id * batch_size)
                 batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}
             else:
-                batch_input_dict = {key: input_dict[key] for key in input_dict}
+                batch_input_dict = {**input_dict}
 
             # send batch input data to input handle(s)
             if self.engine != "onnx":
Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 # Data.batch_transform (batch preprocessing) module
 
-::: ppsci.data.process.transform
+::: ppsci.data.process.batch_transform
     handler: python
     options:
       members:
         - build_batch_transforms
+        - FunctionalBatchTransform
         - default_collate_fn
       show_root_heading: true

docs/zh/development.md

Lines changed: 1 addition & 1 deletion
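@@ -926,7 +926,7 @@ PaddleScience is an open-source codebase developed jointly by many contributors, so
 PaddleScience uses automated code checking and formatting plugins, including [isort](https://github.com/PyCQA/isort#installing-isort) and [black](https://github.com/psf/black),
 so that committed code follows the Python [PEP8](https://pep8.org/) code style.
 
-Therefore, before committing your code, be sure to run the following commands to install `pre-commit`; otherwise the code-style check will flag your PR as unformatted and it cannot be merged.
+Therefore, before committing your code, be sure to run the following commands in the `PaddleScience/` directory to install `pre-commit`; otherwise the code-style check will flag your PR as unformatted and it cannot be merged.
 
 ``` sh
 pip install pre-commit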

docs/zh/examples/adv_cvit.md

Lines changed: 5 additions & 0 deletions
@@ -49,17 +49,22 @@ CVit, as an operator learning model, takes the input function $u$ and the function $s$'s query…
 This problem solves the following equation:
 
 Formulation The 1D advection equation in $\Omega=[0,1)$ is
+
 $$
 \begin{aligned}
 & \frac{\partial u}{\partial t}+c \frac{\partial u}{\partial x}=0 \quad x \in \Omega, \\
 & u(0)=u_0
 \end{aligned}
 $$
+
 where $c=1$ is the constant advection speed, and periodic boundary conditions are imposed. We are interested in the map from the initial $u_0$ to solution $u(\cdot, T)$ at $T=0.5$. The initial condition $u_0$ is assumed to be
+
 $$
 u_0=-1+2 \mathbb{1}\left\{\tilde{u_0} \geq 0\right\}
 $$
+
 where $\widetilde{u_0}$ a centered Gaussian
+
 $$
 \widetilde{u_0} \sim \mathbb{N}(0, \mathrm{C}) \quad \text { and } \quad \mathrm{C}=\left(-\Delta+\tau^2\right)^{-d} \text {; }
 $$
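
As a side note on the formulation in this hunk: with a constant speed $c$ and periodic boundaries, the advection equation only transports the initial profile, so $u(x, T) = u_0((x - cT) \bmod 1)$. The sketch below checks that on a small uniform grid. The grid size `n` and the sine curve standing in for the Gaussian sample $\widetilde{u_0}$ are assumptions made purely for illustration; this is not the data generation code used by the example.

```python
import math

# Illustrative assumptions: a tiny grid and a smooth stand-in for the Gaussian sample.
n = 8      # number of grid points on [0, 1)
c = 1.0    # advection speed from the formulation
T = 0.5    # target time from the formulation

x = [i / n for i in range(n)]
# Hypothetical stand-in for u0_tilde; the small offset keeps values away from zero.
u0_tilde = [math.sin(2 * math.pi * (xi + 0.0625)) for xi in x]
# Thresholding step from the formulation: u0 = -1 + 2 * 1{u0_tilde >= 0}
u0 = [-1 + 2 * (1 if v >= 0 else 0) for v in u0_tilde]

# Periodic transport: u(x, T) = u0((x - c*T) mod 1), a cyclic shift by c*T/dx cells.
shift = round(c * T * n) % n
u_T = [u0[(i - shift) % n] for i in range(n)]

print(u0)
print(u_T)
```

With $c=1$ and $T=0.5$ the profile moves by half the domain, so the two halves of the square wave swap places.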

docs/zh/examples/ns_cvit.md

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ CVit, as an operator learning model, takes the input function $u$ and the function $s$'s query…
 This problem is based on incompressible buoyancy-driven flow in a fixed square cavity, and solves the following equations:
 
 Formulation We consider the vorticity-stream $(\omega-\psi)$ formulation of the incompressible Navier-Stokes equations on a two-dimensional periodic domain, $D=D_u=D_v=[0,2 \pi]^2$ :
+
 $$
 \begin{aligned}
 & \frac{\partial \omega}{\partial t}+(v \cdot \nabla) \omega-v \Delta \omega=f^{\prime} \\

ppsci/solver/solver.py

Lines changed: 23 additions & 15 deletions
@@ -710,7 +710,7 @@ def predict(
         self,
         input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
         expr_dict: Optional[Dict[str, Callable]] = None,
-        batch_size: int = 64,
+        batch_size: Optional[int] = 64,
         no_grad: bool = True,
         return_numpy: bool = False,
     ) -> Dict[str, Union[paddle.Tensor, np.ndarray]]:
@@ -720,7 +720,9 @@ def predict(
             input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
             expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guide to
                 compute equation variable with callable function. Defaults to None.
-            batch_size (int, optional): Predicting by batch size. Defaults to 64.
+            batch_size (Optional[int]): Predicting by batch size. If None, data in
+                `input_dict` will be used directly for inference without any batch slicing.
+                Defaults to 64.
             no_grad (bool): Whether set stop_gradient=True for entire prediction, mainly
                 for memory-efficiency. Defaults to True.
             return_numpy (bool): Whether convert result from Tensor to numpy ndarray.
@@ -773,26 +775,32 @@ def predict(
             if self.world_size > 1
             else input_dict
         )
-        local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size
+        local_batch_num = (
+            (local_num_samples_pad + (batch_size - 1)) // batch_size
+            if batch_size is not None
+            else 1
+        )
 
         pred_dict = misc.Prettydefaultdict(list)
         with self.no_grad_context_manager(no_grad), self.no_sync_context_manager(
             self.world_size > 1, self.model
         ):
             for batch_id in range(local_batch_num):
-                batch_input_dict = {}
-                st = batch_id * batch_size
-                ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)
-
                 # prepare batch input dict
-                for key in local_input_dict:
-                    if not paddle.is_tensor(local_input_dict[key]):
-                        batch_input_dict[key] = paddle.to_tensor(
-                            local_input_dict[key][st:ed], paddle.get_default_dtype()
-                        )
-                    else:
-                        batch_input_dict[key] = local_input_dict[key][st:ed]
-                    batch_input_dict[key].stop_gradient = no_grad
+                batch_input_dict = {}
+                if batch_size is not None:
+                    st = batch_id * batch_size
+                    ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)
+                    for key in local_input_dict:
+                        if not paddle.is_tensor(local_input_dict[key]):
+                            batch_input_dict[key] = paddle.to_tensor(
+                                local_input_dict[key][st:ed], paddle.get_default_dtype()
+                            )
+                        else:
+                            batch_input_dict[key] = local_input_dict[key][st:ed]
+                        batch_input_dict[key].stop_gradient = no_grad
+                else:
+                    batch_input_dict = {**local_input_dict}
 
                 # forward
                 with self.autocast_context_manager(self.use_amp, self.amp_level):
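
The batching change in the `ppsci/solver/solver.py` diff can be pictured with a small standalone sketch. The helper name `iter_batches` is hypothetical, and plain Python lists stand in for tensors; tensor conversion, padding, and the multi-device handling of the real `predict` method are left out. It only shows the two paths: ceiling-division slicing when `batch_size` is an integer, and a single pass over a shallow copy when it is `None`.

```python
from typing import Dict, Iterator, List, Optional


def iter_batches(
    input_dict: Dict[str, List[float]], batch_size: Optional[int]
) -> Iterator[Dict[str, List[float]]]:
    """Yield per-batch input dicts; yield everything once when batch_size is None."""
    if batch_size is None:
        # No slicing: a shallow copy, so values may differ in length across keys.
        yield {**input_dict}
        return
    # Assumes a non-empty dict whose values share the same leading length.
    num_samples = len(next(iter(input_dict.values())))
    # Integer ceiling division: number of batches needed to cover all samples.
    num_batches = (num_samples + (batch_size - 1)) // batch_size
    for batch_id in range(num_batches):
        st = batch_id * batch_size
        ed = min(num_samples, (batch_id + 1) * batch_size)
        yield {key: value[st:ed] for key, value in input_dict.items()}


data = {"x": list(range(10)), "y": list(range(10, 20))}
sizes = [len(batch["x"]) for batch in iter_batches(data, batch_size=4)]
whole = list(iter_batches(data, batch_size=None))
print(sizes, len(whole))
```

Ten samples with `batch_size=4` give batches of 4, 4, and 2, while `batch_size=None` produces exactly one batch that reuses the original values without copying them.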

ppsci/utils/config.py

Lines changed: 41 additions & 24 deletions
@@ -42,15 +42,15 @@ class EMAConfig(BaseModel):
         def decay_check(cls, v):
             if v <= 0 or v >= 1:
                 raise ValueError(
-                    f"'decay' should be in (0, 1) when is type of float, but got {v}"
+                    f"'ema.decay' should be in (0, 1) when is type of float, but got {v}"
                 )
             return v
 
         @field_validator("avg_freq")
         def avg_freq_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'avg_freq' should be a positive integer when is type of int, "
+                    "'ema.avg_freq' should be a positive integer when is type of int, "
                     f"but got {v}"
                 )
             return v
@@ -63,15 +63,17 @@ class SWAConfig(BaseModel):
         @field_validator("avg_range")
         def avg_range_check(cls, v, info: ValidationInfo):
             if isinstance(v, tuple) and v[0] > v[1]:
-                raise ValueError(f"'avg_range' should be a valid range, but got {v}.")
+                raise ValueError(
+                    f"'swa.avg_range' should be a valid range, but got {v}."
+                )
             if isinstance(v, tuple) and v[0] < 0:
                 raise ValueError(
-                    "The start epoch of 'avg_range' should be a non-negtive integer"
+                    "The start epoch of 'swa.avg_range' should be a non-negtive integer"
                     f" , but got {v[0]}."
                 )
             if isinstance(v, tuple) and v[1] > info.data["epochs"]:
                 raise ValueError(
-                    "The end epoch of 'avg_range' should not be lager than "
+                    "The end epoch of 'swa.avg_range' should not be lager than "
                     f"'epochs'({info.data['epochs']}), but got {v[1]}."
                 )
             return v
@@ -80,7 +82,7 @@ def avg_range_check(cls, v, info: ValidationInfo):
         def avg_freq_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'avg_freq' should be a positive integer when is type of int, "
+                    "'swa.avg_freq' should be a positive integer when is type of int, "
                     f"but got {v}"
                 )
             return v
@@ -107,7 +109,7 @@ class TrainConfig(BaseModel):
         def epochs_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'epochs' should be a positive integer when is type of int, "
+                    "'TRAIN.epochs' should be a positive integer when is type of int, "
                     f"but got {v}"
                 )
             return v
@@ -116,7 +118,7 @@ def epochs_check(cls, v):
         def iters_per_epoch_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'iters_per_epoch' should be a positive integer when is type of int"
+                    "'TRAIN.iters_per_epoch' should be a positive integer when is type of int"
                     f", but got {v}"
                 )
             return v
@@ -125,7 +127,7 @@ def iters_per_epoch_check(cls, v):
         def update_freq_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'update_freq' should be a positive integer when is type of int"
+                    "'TRAIN.update_freq' should be a positive integer when is type of int"
                     f", but got {v}"
                 )
             return v
@@ -134,7 +136,7 @@ def update_freq_check(cls, v):
         def save_freq_check(cls, v):
             if v < 0:
                 raise ValueError(
-                    "'save_freq' should be a non-negtive integer when is type of int"
+                    "'TRAIN.save_freq' should be a non-negtive integer when is type of int"
                     f", but got {v}"
                 )
             return v
@@ -144,8 +146,8 @@ def start_eval_epoch_check(cls, v, info: ValidationInfo):
             if info.data["eval_during_train"]:
                 if v <= 0:
                     raise ValueError(
-                        f"'start_eval_epoch' should be a positive integer when "
-                        f"'eval_during_train' is True, but got {v}"
+                        f"'TRAIN.start_eval_epoch' should be a positive integer when "
+                        f"'TRAIN.eval_during_train' is True, but got {v}"
                     )
             return v
 
@@ -154,8 +156,8 @@ def eval_freq_check(cls, v, info: ValidationInfo):
             if info.data["eval_during_train"]:
                 if v <= 0:
                     raise ValueError(
-                        f"'eval_freq' should be a positive integer when "
-                        f"'eval_during_train' is True, but got {v}"
+                        f"'TRAIN.eval_freq' should be a positive integer when "
+                        f"'TRAIN.eval_during_train' is True, but got {v}"
                     )
             return v
 
@@ -176,6 +178,15 @@ class EvalConfig(BaseModel):
         pretrained_model_path: Optional[str] = None
         eval_with_no_grad: bool = False
         compute_metric_by_batch: bool = False
+        batch_size: Optional[int] = 256
+
+        @field_validator("batch_size")
+        def batch_size_check(cls, v):
+            if isinstance(v, int) and v <= 0:
+                raise ValueError(
+                    f"'EVAL.batch_size' should be greater than 0 or None, but got {v}"
+                )
+            return v
 
     class InferConfig(BaseModel):
         """
@@ -203,12 +214,12 @@ class InferConfig(BaseModel):
         def engine_check(cls, v, info: ValidationInfo):
             if v == "tensorrt" and info.data["device"] != "gpu":
                 raise ValueError(
-                    "'device' should be 'gpu' when 'engine' is 'tensorrt', "
+                    "'INFER.device' should be 'gpu' when 'INFER.engine' is 'tensorrt', "
                     f"but got '{info.data['device']}'"
                 )
             if v == "mkldnn" and info.data["device"] != "cpu":
                 raise ValueError(
-                    "'device' should be 'cpu' when 'engine' is 'mkldnn', "
+                    "'INFER.device' should be 'cpu' when 'INFER.engine' is 'mkldnn', "
                     f"but got '{info.data['device']}'"
                 )
 
@@ -218,46 +229,50 @@ def engine_check(cls, v, info: ValidationInfo):
         def min_subgraph_size_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'min_subgraph_size' should be greater than 0, " f"but got {v}"
+                    "'INFER.min_subgraph_size' should be greater than 0, "
+                    f"but got {v}"
                 )
             return v
 
         @field_validator("gpu_mem")
         def gpu_mem_check(cls, v):
             if v <= 0:
-                raise ValueError("'gpu_mem' should be greater than 0, " f"but got {v}")
+                raise ValueError(
+                    "'INFER.gpu_mem' should be greater than 0, " f"but got {v}"
+                )
             return v
 
         @field_validator("gpu_id")
         def gpu_id_check(cls, v):
             if v < 0:
                 raise ValueError(
-                    "'gpu_id' should be greater than or equal to 0, " f"but got {v}"
+                    "'INFER.gpu_id' should be greater than or equal to 0, "
+                    f"but got {v}"
                 )
             return v
 
         @field_validator("max_batch_size")
         def max_batch_size_check(cls, v):
             if v <= 0:
                 raise ValueError(
-                    "'max_batch_size' should be greater than 0, " f"but got {v}"
+                    "'INFER.max_batch_size' should be greater than 0, " f"but got {v}"
                 )
             return v
 
         @field_validator("num_cpu_threads")
         def num_cpu_threads_check(cls, v):
             if v < 0:
                 raise ValueError(
-                    "'num_cpu_threads' should be greater than or equal to 0, "
+                    "'INFER.num_cpu_threads' should be greater than or equal to 0, "
                     f"but got {v}"
                 )
             return v
 
         @field_validator("batch_size")
         def batch_size_check(cls, v):
-            if v <= 0:
+            if isinstance(v, int) and v <= 0:
                 raise ValueError(
-                    "'batch_size' should be greater than 0, " f"but got {v}"
+                    f"'INFER.batch_size' should be greater than 0 or None, but got {v}"
                 )
             return v
 
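
The reason for the `isinstance(v, int)` guard in the new `batch_size_check` is that comparing `None <= 0` raises a `TypeError` in Python 3, so the old `if v <= 0` could not accept `None`. The rule can be sketched as a plain function, independent of pydantic. The name `check_batch_size` and its `name` parameter are hypothetical and exist only for this illustration; the real checks are the `field_validator` methods in the diff.

```python
from typing import Optional


def check_batch_size(v: Optional[int], name: str = "INFER.batch_size") -> Optional[int]:
    """Accept None or a positive int; reject zero and negative ints."""
    # The isinstance guard skips the comparison for None instead of raising TypeError.
    if isinstance(v, int) and v <= 0:
        raise ValueError(f"'{name}' should be greater than 0 or None, but got {v}")
    return v


accepted = [check_batch_size(None), check_batch_size(256)]
try:
    check_batch_size(0)
    rejected = False
except ValueError:
    rejected = True
print(accepted, rejected)
```

`None` and positive integers pass through unchanged, while `0` or a negative value raises a `ValueError` that names the offending config key.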
@@ -326,7 +341,8 @@ def use_wandb_check(cls, v, info: ValidationInfo):
       - TRAIN/swa: swa_default <-- 'swa_default' used here
       - EVAL: eval_default <-- 'eval_default' used here
       - INFER: infer_default <-- 'infer_default' used here
-      - _self_
+      - _self_ <-- config defined in current yaml
+
     mode: train
     seed: 42
     ...
@@ -384,6 +400,7 @@ def use_wandb_check(cls, v, info: ValidationInfo):
         "EVAL.pretrained_model_path",
         "EVAL.eval_with_no_grad",
         "EVAL.compute_metric_by_batch",
+        "EVAL.batch_size",
         "INFER.pretrained_model_path",
         "INFER.export_path",
         "INFER.pdmodel_path",
