Skip to content

Commit 65f4447

Browse files
【PPSCI Doc No.38-40】 (#826)
* ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix * ppsci.equation.PDE.parameters/state_dict/set_state_dict api fix --------- Co-authored-by: krp <[email protected]>
1 parent 166caf5 commit 65f4447

File tree

1 file changed

+57
-5
lines changed

1 file changed

+57
-5
lines changed

ppsci/equation/pde/base.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,71 @@ def add_equation(self, name: str, equation: Callable):
8383
self.equations.update({name: equation})
8484

8585
def parameters(self) -> List[paddle.Tensor]:
86-
"""Return parameters contained in PDE.
86+
"""Return learnable parameters contained in PDE.
87+
88+
Args:
89+
None
8790
8891
Returns:
89-
List[Tensor]: A list of parameters.
92+
List[Tensor]: A list of learnable parameters.
93+
94+
Examples:
95+
>>> import ppsci
96+
>>> pde = ppsci.equation.Vibration(2, -4, 0)
97+
>>> print(pde.parameters())
98+
[Parameter containing:
99+
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
100+
-4.), Parameter containing:
101+
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
102+
0.)]
90103
"""
91104
return self.learnable_parameters.parameters()
92105

93106
def state_dict(self) -> Dict[str, paddle.Tensor]:
94-
"""Return named parameters in dict."""
107+
"""Return named learnable parameters in dict.
108+
109+
Args:
110+
None
111+
112+
Returns:
113+
Dict[str, Tensor]: A dict of states(str) and learnable parameters(Tensor).
114+
115+
Examples:
116+
>>> import ppsci
117+
>>> pde = ppsci.equation.Vibration(2, -4, 0)
118+
>>> print(pde.state_dict())
119+
OrderedDict([('0', Parameter containing:
120+
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
121+
-4.)), ('1', Parameter containing:
122+
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
123+
0.))])
124+
"""
125+
95126
return self.learnable_parameters.state_dict()
96127

97-
def set_state_dict(self, state_dict):
98-
"""Set state dict from dict."""
128+
def set_state_dict(self, state_dict: Dict[str, paddle.Tensor]):
129+
"""Set state dict from dict.
130+
131+
Args:
132+
state_dict (Dict[str, paddle.Tensor]): The state dict to be set.
133+
134+
Returns:
135+
None
136+
137+
Examples:
138+
>>> import paddle
139+
>>> import ppsci
140+
>>> paddle.set_default_dtype("float64")
141+
>>> pde = ppsci.equation.Vibration(2, -4, 0)
142+
>>> state = pde.state_dict()
143+
>>> state['0'] = paddle.to_tensor(-3.1)
144+
>>> pde.set_state_dict(state)
145+
>>> print(state)
146+
OrderedDict([('0', Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
147+
-3.10000000)), ('1', Parameter containing:
148+
Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
149+
0.))])
150+
"""
99151
self.learnable_parameters.set_state_dict(state_dict)
100152

101153
def __str__(self):

0 commit comments

Comments
 (0)