@@ -409,14 +409,79 @@ def evaluate(cfg: DictConfig):
409
409
solver .eval ()
410
410
411
411
412
+ def export (cfg : DictConfig ):
413
+ # set model
414
+ model_re = ppsci .arch .MLP (** cfg .MODEL .re_net )
415
+ model_im = ppsci .arch .MLP (** cfg .MODEL .im_net )
416
+ model_eps = ppsci .arch .MLP (** cfg .MODEL .eps_net )
417
+
418
+ # register transform
419
+ model_re .register_input_transform (func_module .transform_in )
420
+ model_im .register_input_transform (func_module .transform_in )
421
+ model_eps .register_input_transform (func_module .transform_in )
422
+
423
+ model_re .register_output_transform (func_module .transform_out_real_part )
424
+ model_im .register_output_transform (func_module .transform_out_imaginary_part )
425
+ model_eps .register_output_transform (func_module .transform_out_epsilon )
426
+
427
+ # wrap to a model_list
428
+ model_list = ppsci .arch .ModelList ((model_re , model_im , model_eps ))
429
+
430
+ # initialize solver
431
+ solver = ppsci .solver .Solver (
432
+ model_list ,
433
+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
434
+ )
435
+
436
+ # export model
437
+ from paddle .static import InputSpec
438
+
439
+ input_spec = [
440
+ {key : InputSpec ([None , 1 ], "float32" , name = key ) for key in ["x" , "y" ]},
441
+ ]
442
+ solver .export (input_spec , cfg .INFER .export_path )
443
+
444
+
445
+ def inference (cfg : DictConfig ):
446
+ from deploy .python_infer import pinn_predictor
447
+
448
+ predictor = pinn_predictor .PINNPredictor (cfg )
449
+
450
+ valid_dict = ppsci .utils .reader .load_mat_file (
451
+ cfg .DATASET_PATH_VALID , ("x_val" , "y_val" , "bound" )
452
+ )
453
+ input_dict = {"x" : valid_dict ["x_val" ], "y" : valid_dict ["y_val" ]}
454
+
455
+ output_dict = predictor .predict (input_dict , cfg .INFER .batch_size )
456
+
457
+ # mapping data to cfg.INFER.output_keys
458
+ output_dict = {
459
+ store_key : output_dict [infer_key ]
460
+ for store_key , infer_key in zip (cfg .INFER .output_keys , output_dict .keys ())
461
+ }
462
+
463
+ ppsci .visualize .save_vtu_from_dict (
464
+ "./hpinns_pred.vtu" ,
465
+ {** input_dict , ** output_dict },
466
+ input_dict .keys (),
467
+ cfg .INFER .output_keys ,
468
+ )
469
+
470
+
412
471
@hydra .main (version_base = None , config_path = "./conf" , config_name = "hpinns.yaml" )
413
472
def main (cfg : DictConfig ):
414
473
if cfg .mode == "train" :
415
474
train (cfg )
416
475
elif cfg .mode == "eval" :
417
476
evaluate (cfg )
477
+ elif cfg .mode == "export" :
478
+ export (cfg )
479
+ elif cfg .mode == "infer" :
480
+ inference (cfg )
418
481
else :
419
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
482
+ raise ValueError (
483
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
484
+ )
420
485
421
486
422
487
if __name__ == "__main__" :
0 commit comments