Skip to content

Commit 154c01c

Browse files
[API] Add chamfer loss (#871)
* add chamfer loss * update development.md * update example of chamfer loss docstring * fix doctest for ChamferLoss * support batch comupute for chamferloss
1 parent 9d84962 commit 154c01c

File tree

8 files changed

+369
-5
lines changed

8 files changed

+369
-5
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ PaddleScience 项目欢迎并依赖开发人员和开源社区中的用户,会
220220
旨在鼓励更多的开发者参与到飞桨科学计算社区的开源建设中,帮助社区修复 bug 或贡献 feature,加入开源、共建飞桨。了解编程基本知识的入门用户即可参与,活动进行中:
221221
[PaddleScience 快乐开源活动表单](https://github.com/PaddlePaddle/PaddleScience/issues/379)
222222

223-
- 🔥第五期黑客松
223+
- 🔥第六期黑客松
224224

225-
面向全球开发者的深度学习领域编程活动,鼓励开发者了解与参与飞桨深度学习开源项目与文心大模型开发实践。活动进行中:[【PaddlePaddle Hackathon 5th】开源贡献个人挑战赛](https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_5th/%E3%80%90PaddlePaddle%20Hackathon%205th%E3%80%91%E5%BC%80%E6%BA%90%E8%B4%A1%E7%8C%AE%E4%B8%AA%E4%BA%BA%E6%8C%91%E6%88%98%E8%B5%9B%E7%A7%91%E5%AD%A6%E8%AE%A1%E7%AE%97%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md#%E4%BB%BB%E5%8A%A1%E5%BC%80%E5%8F%91%E6%B5%81%E7%A8%8B%E4%B8%8E%E9%AA%8C%E6%94%B6%E6%A0%87%E5%87%86)
225+
面向全球开发者的深度学习领域编程活动,鼓励开发者了解与参与飞桨深度学习开源项目与文心大模型开发实践。活动进行中:[【PaddlePaddle Hackathon 5th】开源贡献个人挑战赛](https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_6th/%E3%80%90Hackathon%206th%E3%80%91%E5%BC%80%E6%BA%90%E8%B4%A1%E7%8C%AE%E4%B8%AA%E4%BA%BA%E6%8C%91%E6%88%98%E8%B5%9B%E7%A7%91%E5%AD%A6%E8%AE%A1%E7%AE%97%E4%BB%BB%E5%8A%A1%E5%90%88%E9%9B%86.md)
226226
<!-- --8<-- [end:contribution] -->
227227

228228
<!-- --8<-- [start:collaboration] -->

docs/zh/api/loss/loss.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- L2RelLoss
1212
- MAELoss
1313
- MSELoss
14+
- ChamferLoss
1415
- CausalMSELoss
1516
- MSELossWithL2Decay
1617
- IntegralLoss

docs/zh/development.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,15 @@ model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")
116116
117117
``` py
118118
--8<--
119-
ppsci/arch/mlp.py:86:151
119+
ppsci/arch/mlp.py:139:279
120120
--8<--
121121
```
122122
123123
=== "MLP.forward"
124124
125125
``` py
126126
--8<--
127-
ppsci/arch/mlp.py:153:180
127+
ppsci/arch/mlp.py:298:315
128128
--8<--
129129
```
130130

docs/zh/examples/allen_cahn.md

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Allen-Cahn
2+
3+
<!-- <a href="TODO" class="md-button md-button--primary" style>AI Studio快速体验</a> -->
4+
5+
=== "模型训练命令"
6+
7+
``` sh
8+
python allen_cahn_default.py
9+
```
10+
11+
=== "模型评估命令"
12+
13+
``` sh
14+
python allen_cahn_default.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/allen_cahn/allen_cahn_default_pretrained.pdparams
15+
```
16+
17+
=== "模型导出命令"
18+
19+
``` sh
20+
python allen_cahn_default.py mode=export
21+
```
22+
23+
=== "模型推理命令"
24+
25+
``` sh
26+
python allen_cahn_default.py mode=infer
27+
```
28+
29+
| 预训练模型 | 指标 |
30+
|:--| :--|
31+
| [allen_cahn_default_pretrained.pdparams](TODO) | TODO |
32+
33+
## 1. 背景简介
34+
35+
Allen-Cahn 方程(有时也叫作模型方程或相场方程)是一种数学模型,通常用于描述两种不同相之间的界面演化。这个方程最早由Samuel Allen和John Cahn在1970年代提出,用以描述合金中相分离的过程。Allen-Cahn 方程是一种非线性偏微分方程,其一般形式可以写为:
36+
37+
$$ \frac{\partial u}{\partial t} = \varepsilon^2 \Delta u - F'(u) $$
38+
39+
这里:
40+
41+
- $u(\mathbf{x},t)$ 是一个场变量,代表某个物理量,例如合金的组分浓度或者晶体中的有序参数。
42+
- $t$ 表示时间。
43+
- $\mathbf{x}$ 表示空间位置。
44+
- $\Delta$ 是Laplace算子,对应于空间变量的二阶偏导数(即 $\Delta u = \nabla^2 u$ ),用来描述空间扩散过程。
45+
- $\varepsilon$ 是一个正的小参数,它与相界面的宽度相关。
46+
- $F(u)$ 是一个双稳态势能函数,通常取为$F(u) = \frac{1}{4}(u^2-1)^2$,这使得 $F'(u) = u^3 - u$ 是其导数,这代表了非线性的反应项,负责驱动系统向稳定状态演化。
47+
48+
这个方程中的 $F'(u)$ 项使得在 $u=1$ 和 $u=-1$ 附近有两个稳定的平衡态,这对应于不同的物理相。而 $\varepsilon^2 \Delta u$ 项则描述了相界面的曲率引起的扩散效应,这导致界面趋向于减小曲率。因此,Allen-Cahn 方程描述了由于相界面曲率和势能影响而发生的相变。
49+
50+
在实际应用中,该方程还可能包含边界条件和初始条件,以便对特定问题进行数值模拟和分析。例如,在特定的物理问题中,可能会有 Neumann 边界条件(导数为零,表示无通量穿过边界)或 Dirichlet 边界条件(固定的边界值)。
51+
52+
本案例解决以下 Allen-Cahn 方程:
53+
54+
$$
55+
\begin{aligned}
56+
& u_t - 0.0001 u_{xx} + 5 u^3 - 5 u = 0,\quad t \in [0, 1],\ x\in[-1, 1],\\
57+
&u(x,0) = x^2 \cos(\pi x),\\
58+
&u(t, -1) = u(t, 1),\\
59+
&u_x(t, -1) = u_x(t, 1).
60+
\end{aligned}
61+
$$
62+
63+
## 2. 问题定义
64+
65+
根据上述方程,可知计算域为$[0, 1]\times [-1, 1]$,含有一个初始条件: $u(x,0) = x^2 \cos(\pi x)$,两个周期边界条件:$u(t, -1) = u(t, 1)$、$u_x(t, -1) = u_x(t, 1)$。
66+
67+
## 3. 问题求解
68+
69+
接下来开始讲解如何将问题一步一步地转化为 PaddleScience 代码,用深度学习的方法求解该问题。
70+
为了快速理解 PaddleScience,接下来仅对模型构建、方程构建、计算域构建等关键步骤进行阐述,而其余细节请参考 [API文档](../api/arch.md)
71+
72+
### 3.1 模型构建
73+
74+
在 Allen-Cahn 问题中,每一个已知的坐标点 $(t, x)$ 都有对应的待求解的未知量 $(u)$,
75+
,在这里使用比较简单的 MLP(Multilayer Perceptron, 多层感知机) 来表示 $(t, x)$ 到 $(u)$ 的映射函数 $f: \mathbb{R}^2 \to \mathbb{R}^1$ ,即:
76+
77+
$$
78+
u = f(t, x)
79+
$$
80+
81+
上式中 $f$ 即为 MLP 模型本身,用 PaddleScience 代码表示如下
82+
83+
``` py linenums="63"
84+
--8<--
85+
examples/allen_cahn/allen_cahn_default.py:63:64
86+
--8<--
87+
```
88+
89+
为了在计算时,准确快速地访问具体变量的值,在这里指定网络模型的输入变量名是 `("t", "x")`,输出变量名是 `("u")`,这些命名与后续代码保持一致。
90+
91+
接着通过指定 MLP 的层数、神经元个数,就实例化出了一个拥有 4 层隐藏神经元,每层神经元数为 256 的神经网络模型 `model`,使用 `tanh` 作为激活函数。
92+
93+
``` yaml linenums="35"
94+
--8<--
95+
examples/allen_cahn/conf/allen_cahn_default.yaml:35:41
96+
--8<--
97+
```
98+
99+
### 3.2 方程构建
100+
101+
Allen-Cahn 微分方程可以用如下代码表示:
102+
103+
``` py linenums="66"
104+
--8<--
105+
examples/allen_cahn/allen_cahn_default.py:66:67
106+
--8<--
107+
```
108+
109+
### 3.3 计算域构建
110+
111+
本问题的计算域为 $[0, 1]\times [-1, 1]$,其中用于训练的数据已提前生成,保存在 `./dataset/allen_cahn.mat` 中,读取并生成计算域内的离散点。
112+
113+
``` py linenums="69"
114+
--8<--
115+
examples/allen_cahn/allen_cahn_default.py:69:81
116+
--8<--
117+
```
118+
119+
### 3.4 约束构建
120+
121+
#### 3.4.1 内部点约束
122+
123+
以作用在内部点上的 `SupervisedConstraint` 为例,代码如下:
124+
125+
``` py linenums="94"
126+
--8<--
127+
examples/allen_cahn/allen_cahn_default.py:94:110
128+
--8<--
129+
```
130+
131+
`SupervisedConstraint` 的第一个参数是用于训练的数据配置,由于我们使用实时随机生成的数据,而不是固定数据点,因此填入自定义的输入数据/标签生成函数;
132+
133+
第二个参数是方程表达式,因此传入 Allen-Cahn 的方程对象;
134+
135+
第三个参数是损失函数,此处选用 `CausalMSELoss` 函数,其会根据 `causal``tol` 参数,对不同的时间窗口进行重新加权, 能更好地优化瞬态问题;
136+
137+
第四个参数是约束条件的名字,需要给每一个约束条件命名,方便后续对其索引。此处命名为 "PDE" 即可。
138+
139+
#### 3.4.2 周期边界约束
140+
141+
此处我们采用 hard-constraint 的方式,在神经网络模型中,对输入数据使用cos、sin等周期函数进行周期化,从而让$u_{\theta}$在数学上直接满足方程的周期性质。
142+
根据方程可得函数$u(t, x)$在$x$轴上的周期为2,因此将该周期设置到模型配置里即可。
143+
144+
``` yaml linenums="35"
145+
--8<--
146+
examples/allen_cahn/conf/allen_cahn_default.yaml:35:43
147+
--8<--
148+
```
149+
150+
#### 3.4.3 初值约束
151+
152+
第三个约束条件是初值约束,代码如下:
153+
154+
``` py linenums="112"
155+
--8<--
156+
examples/allen_cahn/allen_cahn_default.py:112:125
157+
--8<--
158+
```
159+
160+
在微分方程约束、初值约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。
161+
162+
``` py linenums="126"
163+
--8<--
164+
examples/allen_cahn/allen_cahn_default.py:126:130
165+
--8<--
166+
```
167+
168+
### 3.5 超参数设定
169+
170+
接下来需要指定训练轮数和学习率,此处按实验经验,使用 200 轮训练轮数,0.001 的初始学习率。
171+
172+
``` yaml linenums="51"
173+
--8<--
174+
examples/allen_cahn/conf/allen_cahn_default.yaml:51:73
175+
--8<--
176+
```
177+
178+
### 3.6 优化器构建
179+
180+
训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,并配合使用机器学习中常用的 ExponentialDecay 学习率调整策略。
181+
182+
``` py linenums="132"
183+
--8<--
184+
examples/allen_cahn/allen_cahn_default.py:132:136
185+
--8<--
186+
```
187+
188+
### 3.7 评估器构建
189+
190+
在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器。
191+
192+
``` py linenums="138"
193+
--8<--
194+
examples/allen_cahn/allen_cahn_default.py:138:156
195+
--8<--
196+
```
197+
198+
### 3.9 模型训练、评估与可视化
199+
200+
完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估、可视化。
201+
202+
``` py linenums="158"
203+
--8<--
204+
examples/allen_cahn/allen_cahn_default.py:158:194
205+
--8<--
206+
```
207+
208+
## 4. 完整代码
209+
210+
``` py linenums="1" title="allen_cahn_default.py"
211+
--8<--
212+
examples/allen_cahn/allen_cahn_default.py
213+
--8<--
214+
```
215+
216+
## 5. 结果展示
217+
218+
在计算域上均匀采样出 $201\times501$ 个点,其预测结果和解析解如下图所示。
219+
220+
<figure markdown>
221+
![allen_cahn_default.jpg](https://paddle-org.bj.bcebos.com/paddlescience/docs/AllenCahn/allen_cahn_default.png){ loading=lazy }
222+
<figcaption> 左侧为 PaddleScience 预测结果,中间为解析解结果,右侧为两者的差值</figcaption>
223+
</figure>
224+
225+
可以看到对于函数$u(t, x)$,模型的预测结果和解析解的结果基本一致。
226+
227+
## 6. 参考资料
228+
229+
- [Allen-Cahn equation](https://github.com/PredictiveIntelligenceLab/jaxpi/blob/main/examples/allen_cahn/README.md)

ppsci/loss/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616

1717
from ppsci.loss.base import Loss
18+
from ppsci.loss.chamfer import ChamferLoss
1819
from ppsci.loss.func import FunctionalLoss
1920
from ppsci.loss.integral import IntegralLoss
2021
from ppsci.loss.kl import KLLoss
@@ -40,6 +41,7 @@
4041
"PeriodicL2Loss",
4142
"MAELoss",
4243
"CausalMSELoss",
44+
"ChamferLoss",
4345
"MSELoss",
4446
"MSELossWithL2Decay",
4547
"PeriodicMSELoss",

ppsci/loss/chamfer.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Dict
18+
from typing import Optional
19+
from typing import Union
20+
21+
import paddle
22+
23+
from ppsci.loss import base
24+
25+
26+
class ChamferLoss(base.Loss):
27+
r"""Class for Chamfe distance loss.
28+
29+
$$
30+
L = \dfrac{1}{S_1} \sum_{x \in S_1} \min_{y \in S_2} \Vert x - y \Vert_2^2 + \dfrac{1}{S_2} \sum_{y \in S_2} \min_{x \in S_1} \Vert y - x \Vert_2^2
31+
$$
32+
33+
$$
34+
\text{where } S_1 \text{ and } S_2 \text{ is the coordinate matrix of two point clouds}.
35+
$$
36+
37+
Args:
38+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
39+
40+
Examples:
41+
>>> import paddle
42+
>>> from ppsci.loss import ChamferLoss
43+
>>> _ = paddle.seed(42)
44+
>>> batch_point_cloud1 = paddle.rand([2, 100, 3])
45+
>>> batch_point_cloud2 = paddle.rand([2, 50, 3])
46+
>>> output_dict = {"s1": batch_point_cloud1}
47+
>>> label_dict = {"s1": batch_point_cloud2}
48+
>>> weight = {"s1": 0.8}
49+
>>> loss = ChamferLoss(weight=weight)
50+
>>> result = loss(output_dict, label_dict)
51+
>>> print(result)
52+
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
53+
0.04415882)
54+
"""
55+
56+
def __init__(
57+
self,
58+
weight: Optional[Union[float, Dict[str, float]]] = None,
59+
):
60+
super().__init__("mean", weight)
61+
62+
def forward(self, output_dict, label_dict, weight_dict=None):
63+
losses = 0.0
64+
for key in label_dict:
65+
s1 = output_dict[key]
66+
s2 = label_dict[key]
67+
N1, N2 = s1.shape[1], s2.shape[1]
68+
69+
# [B, N1, N2, 3]
70+
s1_expand = paddle.expand(s1.reshape([-1, N1, 1, 3]), shape=[-1, N1, N2, 3])
71+
# [B, N1, N2, 3]
72+
s2_expand = paddle.expand(s2.reshape([-1, 1, N2, 3]), shape=[-1, N1, N2, 3])
73+
74+
dis = ((s1_expand - s2_expand) ** 2).sum(axis=3) # [B, N1, N2]
75+
loss_s12 = dis.min(axis=2) # [B, N1]
76+
loss_s21 = dis.min(axis=1) # [B, N2]
77+
loss = loss_s12.mean() + loss_s21.mean()
78+
79+
if weight_dict and key in weight_dict:
80+
loss *= weight_dict[key]
81+
82+
if isinstance(self.weight, (float, int)):
83+
loss *= self.weight
84+
elif isinstance(self.weight, dict) and key in self.weight:
85+
loss *= self.weight[key]
86+
87+
losses += loss
88+
return losses

ppsci/loss/mse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class CausalMSELoss(base.Loss):
110110
where $w_i=\exp (-\epsilon \displaystyle\sum_{k=1}^{i-1} \mathcal{L}_r^k), i=2,3, \ldots, M.$
111111
112112
Args:
113-
n_chunks (int): Number of time windows split.
113+
n_chunks (int): $M$, Number of split time windows.
114114
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
115115
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
116116
tol (float, optional): Causal tolerance, i.e. $\epsilon$ in paper. Defaults to 1.0.

0 commit comments

Comments
 (0)