Skip to content

Commit 2fea541

Browse files
support default config content in config module and remove deprecated AttrDict series code
1 parent dbdb813 commit 2fea541

28 files changed

+295
-431
lines changed

ppsci/arch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def build_model(cfg):
8181
"""Build model
8282
8383
Args:
84-
cfg (AttrDict): Arch config.
84+
cfg (DictConfig): Arch config.
8585
8686
Returns:
8787
nn.Layer: Model.

ppsci/arch/phycrnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
)
148148

149149
# ConvLSTM
150-
self.convlstm = paddle.nn.LayerList(
150+
self.ConvLSTM = paddle.nn.LayerList(
151151
[
152152
ConvLSTMCell(
153153
input_channels=self.input_channels[i],
@@ -194,16 +194,16 @@ def forward(self, x):
194194
x = encoder(x)
195195

196196
# convlstm
197-
for i, lstm in enumerate(self.convlstm, self.num_encoder):
197+
for i, LSTM in enumerate(self.ConvLSTM):
198198
if step == 0:
199-
(h, c) = lstm.init_hidden_tensor(
199+
(h, c) = LSTM.init_hidden_tensor(
200200
prev_state=self.initial_state[i - self.num_encoder]
201201
)
202202
internal_state.append((h, c))
203203

204204
# one-step forward
205205
(h, c) = internal_state[i - self.num_encoder]
206-
x, new_c = lstm(x, h, c)
206+
x, new_c = LSTM(x, h, c)
207207
internal_state[i - self.num_encoder] = (x, new_c)
208208

209209
# output

ppsci/constraint/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict):
4242
"""Build constraint(s).
4343
4444
Args:
45-
cfg (List[AttrDict]): Constraint config list.
45+
cfg (List[DictConfig]): Constraint config list.
4646
equation_dict (Dct[str, Equation]): Equation(s) in dict.
4747
geom_dict (Dct[str, Geometry]): Geometry(ies) in dict.
4848

ppsci/data/dataset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def build_dataset(cfg) -> "io.Dataset":
7878
"""Build dataset
7979
8080
Args:
81-
cfg (List[AttrDict]): dataset config list.
81+
cfg (List[DictConfig]): dataset config list.
8282
8383
Returns:
8484
Dict[str, io.Dataset]: dataset.

ppsci/equation/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def build_equation(cfg):
5454
"""Build equation(s)
5555
5656
Args:
57-
cfg (List[AttrDict]): Equation(s) config list.
57+
cfg (List[DictConfig]): Equation(s) config list.
5858
5959
Returns:
6060
Dict[str, Equation]: Equation(s) in dict.

ppsci/equation/pde/base.py

Lines changed: 15 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Union
2323

2424
import paddle
25-
import sympy as sp
25+
import sympy
2626
from paddle import nn
2727

2828
DETACH_FUNC_NAME = "detach"
@@ -33,7 +33,7 @@ class PDE:
3333

3434
def __init__(self):
3535
super().__init__()
36-
self.equations: Dict[str, Union[Callable, sp.Basic]] = {}
36+
self.equations = {}
3737
# for PDE which has learnable parameter(s)
3838
self.learnable_parameters = nn.ParameterList()
3939

@@ -42,7 +42,7 @@ def __init__(self):
4242
@staticmethod
4343
def create_symbols(
4444
symbol_str: str,
45-
) -> Union[sp.Symbol, Tuple[sp.Symbol, ...]]:
45+
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
4646
"""create symbolic variables.
4747
4848
Args:
@@ -61,9 +61,11 @@ def create_symbols(
6161
>>> print(symbols_xyz)
6262
(x, y, z)
6363
"""
64-
return sp.symbols(symbol_str)
64+
return sympy.symbols(symbol_str)
6565

66-
def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Function:
66+
def create_function(
67+
self, name: str, invars: Tuple[sympy.Symbol, ...]
68+
) -> sympy.Function:
6769
"""Create named function depending on given invars.
6870
6971
Args:
@@ -84,73 +86,14 @@ def create_function(self, name: str, invars: Tuple[sp.Symbol, ...]) -> sp.Functi
8486
>>> print(f)
8587
f(x, y, z)
8688
"""
87-
expr = sp.Function(name)(*invars)
89+
expr = sympy.Function(name)(*invars)
8890

91+
# wrap `expression(...)` to `detach(expression(...))`
92+
# if name of expression is in given detach_keys
93+
if self.detach_keys and name in self.detach_keys:
94+
expr = sympy.Function(DETACH_FUNC_NAME)(expr)
8995
return expr
9096

91-
def _apply_detach(self):
92-
"""
93-
Wrap detached sub_expr into detach(sub_expr) to prevent gradient
94-
back-propagation, only for those items speicified in self.detach_keys.
95-
96-
NOTE: This function is expected to be called after self.equations is ready in PDE.__init__.
97-
98-
Examples:
99-
>>> import ppsci
100-
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False)
101-
>>> print(ns)
102-
NavierStokes
103-
continuity: Derivative(u(x, y), x) + Derivative(v(x, y), y)
104-
momentum_x: u(x, y)*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
105-
momentum_y: u(x, y)*Derivative(v(x, y), x) + v(x, y)*Derivative(v(x, y), y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
106-
>>> detach_keys = ("u", "v__y")
107-
>>> ns = ppsci.equation.NavierStokes(1.0, 1.0, 2, False, detach_keys=detach_keys)
108-
>>> print(ns)
109-
NavierStokes
110-
continuity: detach(Derivative(v(x, y), y)) + Derivative(u(x, y), x)
111-
momentum_x: detach(u(x, y))*Derivative(u(x, y), x) + v(x, y)*Derivative(u(x, y), y) + 1.0*Derivative(p(x, y), x) - 1.0*Derivative(u(x, y), (x, 2)) - 1.0*Derivative(u(x, y), (y, 2))
112-
momentum_y: detach(u(x, y))*Derivative(v(x, y), x) + detach(Derivative(v(x, y), y))*v(x, y) + 1.0*Derivative(p(x, y), y) - 1.0*Derivative(v(x, y), (x, 2)) - 1.0*Derivative(v(x, y), (y, 2))
113-
"""
114-
if self.detach_keys is None:
115-
return
116-
117-
from copy import deepcopy
118-
119-
from sympy.core.traversal import postorder_traversal
120-
121-
from ppsci.utils.symbolic import _cvt_to_key
122-
123-
for name, expr in self.equations.items():
124-
if not isinstance(expr, sp.Basic):
125-
continue
126-
# only process sympy expression
127-
expr_ = deepcopy(expr)
128-
for item in postorder_traversal(expr):
129-
if _cvt_to_key(item) in self.detach_keys:
130-
# inplace all related sub_expr into detach(sub_expr)
131-
expr_ = expr_.replace(item, sp.Function(DETACH_FUNC_NAME)(item))
132-
133-
# remove all detach wrapper for more-than-once wrapped items to prevent duplicated wrapping
134-
expr_ = expr_.replace(
135-
sp.Function(DETACH_FUNC_NAME)(
136-
sp.Function(DETACH_FUNC_NAME)(item)
137-
),
138-
sp.Function(DETACH_FUNC_NAME)(item),
139-
)
140-
141-
# remove unccessary detach wrapping for the first arg of Derivative
142-
for item_ in list(postorder_traversal(expr_)):
143-
if isinstance(item_, sp.Derivative):
144-
if item_.args[0].name == DETACH_FUNC_NAME:
145-
expr_ = expr_.replace(
146-
item_,
147-
sp.Derivative(
148-
item_.args[0].args[0], *item_.args[1:]
149-
),
150-
)
151-
152-
self.equations[name] = expr_
153-
15497
def add_equation(self, name: str, equation: Callable):
15598
"""Add an equation.
15699
@@ -167,8 +110,7 @@ def add_equation(self, name: str, equation: Callable):
167110
>>> equation = sympy.diff(u, x) + sympy.diff(u, y)
168111
>>> pde.add_equation('linear_pde', equation)
169112
>>> print(pde)
170-
PDE
171-
linear_pde: 2*x + 2*y
113+
PDE, linear_pde: 2*x + 2*y
172114
"""
173115
self.equations.update({name: equation})
174116

@@ -239,7 +181,7 @@ def set_state_dict(
239181
return self.learnable_parameters.set_state_dict(state_dict)
240182

241183
def __str__(self):
242-
return "\n".join(
184+
return ", ".join(
243185
[self.__class__.__name__]
244-
+ [f" {name}: {eq}" for name, eq in self.equations.items()]
186+
+ [f"{name}: {eq}" for name, eq in self.equations.items()]
245187
)

ppsci/equation/pde/biharmonic.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,3 @@ def __init__(
7070
biharmonic += u.diff(invar_i, 2).diff(invar_j, 2)
7171

7272
self.add_equation("biharmonic", biharmonic)
73-
74-
self._apply_detach()

ppsci/equation/pde/heat_exchanger.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,5 +90,3 @@ def __init__(
9090
self.add_equation("heat_boundary", heat_boundary)
9191
self.add_equation("cold_boundary", cold_boundary)
9292
self.add_equation("wall", wall)
93-
94-
self._apply_detach()

ppsci/equation/pde/laplace.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,3 @@ def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
5151
laplace += u.diff(invar, 2)
5252

5353
self.add_equation("laplace", laplace)
54-
55-
self._apply_detach()

ppsci/equation/pde/linear_elasticity.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,3 @@ def __init__(
179179
self.add_equation("traction_y", traction_y)
180180
if self.dim == 3:
181181
self.add_equation("traction_z", traction_z)
182-
183-
self._apply_detach()

0 commit comments

Comments
 (0)