@@ -403,14 +403,118 @@ def transform_out(in_, out):
403
403
)
404
404
405
405
406
+ def export (cfg : DictConfig ):
407
+ # set model
408
+ model_psi = ppsci .arch .MLP (** cfg .MODEL .psi_net )
409
+ model_p = ppsci .arch .MLP (** cfg .MODEL .p_net )
410
+ model_phil = ppsci .arch .MLP (** cfg .MODEL .phil_net )
411
+
412
+ # transform
413
+ def transform_out (in_ , out ):
414
+ psi_y = out ["psi" ]
415
+ y = in_ ["y" ]
416
+ x = in_ ["x" ]
417
+ u = jacobian (psi_y , y , create_graph = False )
418
+ v = - jacobian (psi_y , x , create_graph = False )
419
+ return {"u" : u , "v" : v }
420
+
421
+ # register transform
422
+ model_psi .register_output_transform (transform_out )
423
+ model_list = ppsci .arch .ModelList ((model_psi , model_p , model_phil ))
424
+
425
+ # initialize solver
426
+ solver = ppsci .solver .Solver (
427
+ model_list ,
428
+ pretrained_model_path = cfg .INFER .pretrained_model_path ,
429
+ )
430
+ # export model
431
+ from paddle .static import InputSpec
432
+
433
+ input_spec = [
434
+ {
435
+ key : InputSpec ([None , 1 ], "float32" , name = key )
436
+ for key in model_list .input_keys
437
+ },
438
+ ]
439
+ solver .export (input_spec , cfg .INFER .export_path )
440
+
441
+
442
+ def inference (cfg : DictConfig ):
443
+ # load Data
444
+ data = scipy .io .loadmat (cfg .DATA_PATH )
445
+ # normalize data
446
+ p_max = data ["p" ].max (axis = 0 )
447
+ p_min = data ["p" ].min (axis = 0 )
448
+ u_max = data ["u" ].max (axis = 0 )
449
+ u_min = data ["u" ].min (axis = 0 )
450
+ v_max = data ["v" ].max (axis = 0 )
451
+ v_min = data ["v" ].min (axis = 0 )
452
+
453
+ from deploy .python_infer import pinn_predictor
454
+
455
+ predictor = pinn_predictor .PINNPredictor (cfg )
456
+ # set time-geometry
457
+ timestamps = np .linspace (0 , 126 , 127 , endpoint = True )
458
+ geom = {
459
+ "time_rect_visu" : ppsci .geometry .TimeXGeometry (
460
+ ppsci .geometry .TimeDomain (1 , 126 , timestamps = timestamps ),
461
+ ppsci .geometry .Rectangle ((0 , 0 ), (15 , 5 )),
462
+ ),
463
+ }
464
+ NTIME_ALL = len (timestamps )
465
+ NPOINT_PDE , NTIME_PDE = 300 * 100 , NTIME_ALL - 1
466
+ input_dict = geom ["time_rect_visu" ].sample_interior (
467
+ NPOINT_PDE * NTIME_PDE , evenly = True
468
+ )
469
+ output_dict = predictor .predict (input_dict , cfg .INFER .batch_size )
470
+
471
+ # mapping data to cfg.INFER.output_keys
472
+ output_dict = {
473
+ store_key : output_dict [infer_key ]
474
+ for store_key , infer_key in zip (cfg .MODEL .output_keys , output_dict .keys ())
475
+ }
476
+
477
+ # inverse normalization
478
+ p_pred = output_dict ["p" ].reshape ([NTIME_PDE , NPOINT_PDE ]).T
479
+ u_pred = output_dict ["u" ].reshape ([NTIME_PDE , NPOINT_PDE ]).T
480
+ v_pred = output_dict ["v" ].reshape ([NTIME_PDE , NPOINT_PDE ]).T
481
+ pred = {
482
+ "p" : (p_pred * (p_max - p_min ) + p_min ).T .reshape ([- 1 , 1 ]),
483
+ "u" : (u_pred * (u_max - u_min ) + u_min ).T .reshape ([- 1 , 1 ]),
484
+ "v" : (v_pred * (v_max - v_min ) + v_min ).T .reshape ([- 1 , 1 ]),
485
+ "phil" : output_dict ["phil" ],
486
+ }
487
+ ppsci .visualize .save_vtu_from_dict (
488
+ "./visual/bubble_pred.vtu" ,
489
+ {
490
+ "t" : input_dict ["t" ],
491
+ "x" : input_dict ["x" ],
492
+ "y" : input_dict ["y" ],
493
+ "u" : pred ["u" ],
494
+ "v" : pred ["v" ],
495
+ "p" : pred ["p" ],
496
+ "phil" : pred ["phil" ],
497
+ },
498
+ ("t" , "x" , "y" ),
499
+ ("u" , "v" , "p" , "phil" ),
500
+ NTIME_PDE ,
501
+ )
502
+
503
+
406
504
@hydra .main (version_base = None , config_path = "./conf" , config_name = "bubble.yaml" )
407
505
def main (cfg : DictConfig ):
408
506
if cfg .mode == "train" :
409
507
train (cfg )
410
508
elif cfg .mode == "eval" :
411
509
evaluate (cfg )
510
+ elif cfg .mode == "export" :
511
+ export (cfg )
512
+ elif cfg .mode == "infer" :
513
+ inference (cfg )
412
514
else :
413
- raise ValueError (f"cfg.mode should in ['train', 'eval'], but got '{ cfg .mode } '" )
515
+ raise ValueError (
516
+ f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{ cfg .mode } '"
517
+ )
414
518
415
519
416
520
if __name__ == "__main__" :
0 commit comments