Skip to content

Commit 637b4e7

Browse files
support dim=1/2/3 for SPINN and helmholtz PDE (#1075)
1 parent 78bd343 commit 637b4e7

File tree

4 files changed

+77
-15
lines changed

4 files changed

+77
-15
lines changed

examples/spinn/helmholtz3d.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ def train(cfg: DictConfig):
100100
model = ppsci.arch.SPINN(**cfg.MODEL)
101101

102102
# set equation
103-
equation = {"Helmholtz": ppsci.equation.Helmholtz(3, 1.0)}
104-
equation["Helmholtz"].model = model # set model to equation for hvp
103+
equation = {"Helmholtz": ppsci.equation.Helmholtz(3, 1.0, model)}
105104

106105
# set constraint
107106
class InteriorDataGenerator:

ppsci/arch/spinn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ def _tensor_contraction(self, x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tens
137137

138138
return out
139139

140-
def forward_tensor(self, x, y, z) -> List[paddle.Tensor]:
140+
def forward_tensor(self, *xs) -> List[paddle.Tensor]:
141141
# forward each dim branch
142142
feature_f = []
143-
for i, input_var in enumerate((x, y, z)):
143+
for i, input_var in enumerate(xs):
144144
input_i = {self.input_keys[i]: input_var}
145145
output_f_i = self.branch_nets[i](input_i)
146146
feature_f.append(output_f_i["f"]) # [B, r*output_dim]

ppsci/equation/pde/helmholtz.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,50 @@ def __init__(
6666
self,
6767
dim: int,
6868
k: float,
69+
model: paddle.nn.Layer,
6970
detach_keys: Optional[Tuple[str, ...]] = None,
7071
):
7172
super().__init__()
7273
self.dim = dim
7374
self.k = k
7475
self.detach_keys = detach_keys
7576

76-
self.model: paddle.nn.Layer
77+
invars = self.create_symbols("x y z")[:dim]
7778

78-
def helmholtz(data_dict: Dict[str, "paddle.Tensor"]):
79-
x, y, z = (
80-
data_dict["x"],
81-
data_dict["y"],
82-
data_dict["z"],
83-
)
79+
# TODO: This is a hack, should be simplified in the future
80+
self.model = model
81+
82+
def helmholtz(data_dict: Dict[str, paddle.Tensor]) -> paddle.Tensor:
83+
xs = tuple(data_dict[invar.name] for invar in invars)
8484

8585
# TODO: Hard code here, for hvp_revrev requires tuple input(s) but not dict
86-
u__x__x = hvp_revrev(lambda x_: self.model.forward_tensor(x_, y, z), (x,))
87-
u__y__y = hvp_revrev(lambda y_: self.model.forward_tensor(x, y_, z), (y,))
88-
u__z__z = hvp_revrev(lambda z_: self.model.forward_tensor(x, y, z_), (z,))
86+
if self.dim == 1:
87+
u__x__x = hvp_revrev(lambda x_: self.model.forward_tensor(x_), (xs[0],))
88+
out = (self.k**2) * data_dict["u"] + u__x__x
89+
elif self.dim == 2:
90+
u__x__x = hvp_revrev(
91+
lambda x_: self.model.forward_tensor(x_, xs[1]), (xs[0],)
92+
)
93+
u__y__y = hvp_revrev(
94+
lambda y_: self.model.forward_tensor(xs[0], y_), (xs[1],)
95+
)
96+
out = (self.k**2) * data_dict["u"] + u__x__x + u__y__y
97+
elif self.dim >= 3:
98+
u__x__x = hvp_revrev(
99+
lambda x_: self.model.forward_tensor(x_, xs[1], xs[2]), (xs[0],)
100+
)
101+
u__y__y = hvp_revrev(
102+
lambda y_: self.model.forward_tensor(xs[0], y_, xs[2]), (xs[1],)
103+
)
104+
u__z__z = hvp_revrev(
105+
lambda z_: self.model.forward_tensor(xs[0], xs[1], z_), (xs[2],)
106+
)
107+
out = (self.k**2) * data_dict["u"] + u__x__x + u__y__y + u__z__z
108+
else:
109+
raise NotImplementedError(
110+
f"dim should be less or equal to 3, but got {self.dim}."
111+
)
89112

90-
out = (self.k**2) * data_dict["u"] + u__x__x + u__y__y + u__z__z
91113
return out
92114

93115
self.add_equation("helmholtz", helmholtz)

test/equation/test_helmholtz.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import paddle
3+
import pytest
4+
5+
from ppsci import arch
6+
from ppsci import equation
7+
8+
__all__ = []
9+
10+
11+
@pytest.mark.parametrize("dim", (1, 2, 3))
12+
def test_helmholtz(dim):
13+
"""Test for only mean."""
14+
nrs = [3, 4, 5][:dim]
15+
input_keys = ("x", "y", "z")[:dim]
16+
output_keys = ("u",)
17+
18+
# generate input data
19+
input_dict = {
20+
input_keys[i]: paddle.to_tensor(
21+
np.random.randn(nrs[i], 1).astype(np.float32), stop_gradient=False
22+
)
23+
for i in range(dim)
24+
}
25+
model = arch.SPINN(input_keys, output_keys, r=16, num_layers=2, hidden_size=8)
26+
y = model(input_dict)["u"]
27+
assert y.shape == [*nrs, 1]
28+
data_dict = {
29+
**input_dict,
30+
"u": y,
31+
}
32+
helmholtz_obj = equation.Helmholtz(dim, 1.0, model)
33+
34+
helmholtz_out = helmholtz_obj.equations["helmholtz"](data_dict)
35+
36+
# check result whether is equal
37+
assert helmholtz_out.shape == [*nrs, 1]
38+
39+
40+
if __name__ == "__main__":
41+
pytest.main()

0 commit comments

Comments
 (0)