@@ -89,66 +89,10 @@ def train(cfg: DictConfig):
89
89
# evaluate after finished training
90
90
solver .eval ()
91
91
92
- # visualize prediction for different functions u and corresponding G(u)
93
- dtype = paddle .get_default_dtype ()
94
-
95
- def generate_y_u_G_ref (
96
- u_func : Callable , G_u_func : Callable
97
- ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
98
- """Generate discretized data of given function u and corresponding G(u).
99
-
100
- Args:
101
- u_func (Callable): Function u.
102
- G_u_func (Callable): Function G(u).
92
+ def predict_func (input_dict ):
93
+ return solver .predict (input_dict , return_numpy = True )[cfg .MODEL .G_key ]
103
94
104
- Returns:
105
- Tuple[np.ndarray, np.ndarray, np.ndarray]: Discretized data of u, y and G(u).
106
- """
107
- x = np .linspace (0 , 1 , cfg .MODEL .num_loc , dtype = dtype ).reshape (
108
- [1 , cfg .MODEL .num_loc ]
109
- )
110
- u = u_func (x )
111
- u = np .tile (u , [cfg .NUM_Y , 1 ])
112
-
113
- y = np .linspace (0 , 1 , cfg .NUM_Y , dtype = dtype ).reshape ([cfg .NUM_Y , 1 ])
114
- G_ref = G_u_func (y )
115
- return u , y , G_ref
116
-
117
- func_u_G_pair = [
118
- # (title_string, func_u, func_G(u)), s.t. dG/dx == u and G(u)(0) = 0
119
- (r"$u=\cos(x), G(u)=sin(x$)" , lambda x : np .cos (x ), lambda y : np .sin (y )), # 1
120
- (
121
- r"$u=sec^2(x), G(u)=tan(x$)" ,
122
- lambda x : (1 / np .cos (x )) ** 2 ,
123
- lambda y : np .tan (y ),
124
- ), # 2
125
- (
126
- r"$u=sec(x)tan(x), G(u)=sec(x) - 1$" ,
127
- lambda x : (1 / np .cos (x ) * np .tan (x )),
128
- lambda y : 1 / np .cos (y ) - 1 ,
129
- ), # 3
130
- (
131
- r"$u=1.5^x\ln{1.5}, G(u)=1.5^x-1$" ,
132
- lambda x : 1.5 ** x * np .log (1.5 ),
133
- lambda y : 1.5 ** y - 1 ,
134
- ), # 4
135
- (r"$u=3x^2, G(u)=x^3$" , lambda x : 3 * x ** 2 , lambda y : y ** 3 ), # 5
136
- (r"$u=4x^3, G(u)=x^4$" , lambda x : 4 * x ** 3 , lambda y : y ** 4 ), # 6
137
- (r"$u=5x^4, G(u)=x^5$" , lambda x : 5 * x ** 4 , lambda y : y ** 5 ), # 7
138
- (r"$u=6x^5, G(u)=x^6$" , lambda x : 5 * x ** 4 , lambda y : y ** 5 ), # 8
139
- (r"$u=e^x, G(u)=e^x-1$" , lambda x : np .exp (x ), lambda y : np .exp (y ) - 1 ), # 9
140
- ]
141
-
142
- os .makedirs (os .path .join (cfg .output_dir , "visual" ), exist_ok = True )
143
- for i , (title , u_func , G_func ) in enumerate (func_u_G_pair ):
144
- u , y , G_ref = generate_y_u_G_ref (u_func , G_func )
145
- G_pred = solver .predict ({"u" : u , "y" : y }, return_numpy = True )["G" ]
146
- plt .plot (y , G_pred , label = r"$G(u)(y)_{ref}$" )
147
- plt .plot (y , G_ref , label = r"$G(u)(y)_{pred}$" )
148
- plt .legend ()
149
- plt .title (title )
150
- plt .savefig (os .path .join (cfg .output_dir , "visual" , f"func_{ i } _result.png" ))
151
- plt .clf ()
95
+ plot (cfg , predict_func )
152
96
153
97
154
98
def evaluate (cfg : DictConfig ):
@@ -189,6 +133,50 @@ def evaluate(cfg: DictConfig):
189
133
)
190
134
solver .eval ()
191
135
136
+ def predict_func (input_dict ):
137
+ return solver .predict (input_dict , return_numpy = True )[cfg .MODEL .G_key ]
138
+
139
+ plot (cfg , predict_func )
140
+
141
+
142
+ def export (cfg : DictConfig ):
143
+ # set model
144
+ model = ppsci .arch .DeepONet (** cfg .MODEL )
145
+
146
+ # initialize solver
147
+ solver = ppsci .solver .Solver (
148
+ model ,
149
+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
150
+ )
151
+
152
+ # export model
153
+ from paddle .static import InputSpec
154
+
155
+ input_spec = [
156
+ {
157
+ model .input_keys [0 ]: InputSpec (
158
+ [None , 1000 ], "float32" , name = model .input_keys [0 ]
159
+ ),
160
+ model .input_keys [1 ]: InputSpec (
161
+ [None , 1 ], "float32" , name = model .input_keys [1 ]
162
+ ),
163
+ }
164
+ ]
165
+ solver .export (input_spec , cfg .INFER .export_path )
166
+
167
+
168
+ def inference (cfg : DictConfig ):
169
+ from deploy import python_infer
170
+
171
+ predictor = python_infer .GeneralPredictor (cfg )
172
+
173
+ def predict_func (input_dict ):
174
+ return next (iter (predictor .predict (input_dict ).values ()))
175
+
176
+ plot (cfg , predict_func )
177
+
178
+
179
+ def plot (cfg : DictConfig , predict_func : Callable ):
192
180
# visualize prediction for different functions u and corresponding G(u)
193
181
dtype = paddle .get_default_dtype ()
194
182
@@ -242,13 +230,17 @@ def generate_y_u_G_ref(
242
230
os .makedirs (os .path .join (cfg .output_dir , "visual" ), exist_ok = True )
243
231
for i , (title , u_func , G_func ) in enumerate (func_u_G_pair ):
244
232
u , y , G_ref = generate_y_u_G_ref (u_func , G_func )
245
- G_pred = solver . predict ({"u" : u , "y" : y }, return_numpy = True )[ "G" ]
233
+ G_pred = predict_func ({"u" : u , "y" : y })
246
234
plt .plot (y , G_pred , label = r"$G(u)(y)_{ref}$" )
247
235
plt .plot (y , G_ref , label = r"$G(u)(y)_{pred}$" )
248
236
plt .legend ()
249
237
plt .title (title )
250
238
plt .savefig (os .path .join (cfg .output_dir , "visual" , f"func_{ i } _result.png" ))
239
+ logger .message (
240
+ f"Saved result of function { i } to { cfg .output_dir } /visual/func_{ i } _result.png"
241
+ )
251
242
plt .clf ()
243
+ plt .close ()
252
244
253
245
254
246
@hydra .main (version_base = None , config_path = "./conf" , config_name = "deeponet.yaml" )
@@ -257,8 +249,14 @@ def main(cfg: DictConfig):
257
249
train (cfg )
258
250
elif cfg .mode == "eval" :
259
251
evaluate (cfg )
252
+ elif cfg .mode == "export" :
253
+ export (cfg )
254
+ elif cfg .mode == "infer" :
255
+ inference (cfg )
260
256
else :
261
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
257
+ raise ValueError (
258
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
259
+ )
262
260
263
261
264
262
if __name__ == "__main__" :
0 commit comments