@@ -83,19 +83,71 @@ def add_equation(self, name: str, equation: Callable):
83
83
self .equations .update ({name : equation })
84
84
85
85
def parameters (self ) -> List [paddle .Tensor ]:
86
- """Return parameters contained in PDE.
86
+ """Return learnable parameters contained in PDE.
87
+
88
+ Args:
89
+ None
87
90
88
91
Returns:
89
- List[Tensor]: A list of parameters.
92
+ List[Tensor]: A list of learnable parameters.
93
+
94
+ Examples:
95
+ >>> import ppsci
96
+ >>> pde = ppsci.equation.Vibration(2, -4, 0)
97
+ >>> print(pde.parameters())
98
+ [Parameter containing:
99
+ Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
100
+ -4.), Parameter containing:
101
+ Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,
102
+ 0.)]
90
103
"""
91
104
return self .learnable_parameters .parameters ()
92
105
93
106
def state_dict (self ) -> Dict [str , paddle .Tensor ]:
94
- """Return named parameters in dict."""
107
+ """Return named learnable parameters in dict.
108
+
109
+ Args:
110
+ None
111
+
112
+ Returns:
113
+ Dict[str, Tensor]: A dict of states(str) and learnable parameters(Tensor).
114
+
115
+ Examples:
116
+ >>> import ppsci
117
+ >>> pde = ppsci.equation.Vibration(2, -4, 0)
118
+ >>> print(pde.state_dict())
119
+ OrderedDict([('0', Parameter containing:
120
+ Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
121
+ -4.)), ('1', Parameter containing:
122
+ Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
123
+ 0.))])
124
+ """
125
+
95
126
return self .learnable_parameters .state_dict ()
96
127
97
- def set_state_dict (self , state_dict ):
98
- """Set state dict from dict."""
128
+ def set_state_dict (self , state_dict : Dict [str , paddle .Tensor ]):
129
+ """Set state dict from dict.
130
+
131
+ Args:
132
+ state_dict (Dict[str, paddle.Tensor]): The state dict to be set.
133
+
134
+ Returns:
135
+ None
136
+
137
+ Examples:
138
+ >>> import paddle
139
+ >>> import ppsci
140
+ >>> paddle.set_default_dtype("float64")
141
+ >>> pde = ppsci.equation.Vibration(2, -4, 0)
142
+ >>> state = pde.state_dict()
143
+ >>> state['0'] = paddle.to_tensor(-3.1)
144
+ >>> pde.set_state_dict(state)
145
+ >>> print(state)
146
+ OrderedDict([('0', Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
147
+ -3.10000000)), ('1', Parameter containing:
148
+ Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=False,
149
+ 0.))])
150
+ """
99
151
self .learnable_parameters .set_state_dict (state_dict )
100
152
101
153
def __str__ (self ):
0 commit comments