@@ -283,8 +283,8 @@ def export(cfg: DictConfig):
283
283
284
284
input_spec = [
285
285
{
286
- key : InputSpec ([None , 16 , 128 ], "float32" , name = key )
287
- for key in model . input_keys
286
+ "states" : InputSpec ([1 , 255 , 3 , 64 , 128 ], "float32" , name = "states" ),
287
+ "visc" : InputSpec ([ 1 , 1 ], "float32" , name = "visc" ),
288
288
},
289
289
]
290
290
@@ -309,6 +309,7 @@ def inference(cfg: DictConfig):
309
309
310
310
input_dict = {
311
311
"states" : dataset .data [: cfg .VIS_DATA_NUMS , :- 1 ],
312
+ "visc" : dataset .visc [: cfg .VIS_DATA_NUMS ],
312
313
}
313
314
314
315
output_dict = predictor .predict (input_dict )
@@ -319,17 +320,19 @@ def inference(cfg: DictConfig):
319
320
store_key : output_dict [infer_key ]
320
321
for store_key , infer_key in zip (output_keys , output_dict .keys ())
321
322
}
322
-
323
- input_dict = {
324
- "states" : dataset .data [: cfg .VIS_DATA_NUMS , 1 :],
325
- }
326
-
327
- data_dict = {** input_dict , ** output_dict }
328
323
for i in range (cfg .VIS_DATA_NUMS ):
329
- ppsci .visualize .save_plot_from_3d_dict (
324
+ ppsci .visualize .plot . save_plot_from_2d_dict (
330
325
f"./cylinder_transformer_pred_{ i } " ,
331
- {key : value [i ] for key , value in data_dict .items ()},
332
- ("states" , "pred_states" ),
326
+ {
327
+ "pred_ux" : output_dict ["pred_states" ][i ][:, 0 ],
328
+ "pred_uy" : output_dict ["pred_states" ][i ][:, 1 ],
329
+ "pred_p" : output_dict ["pred_states" ][i ][:, 2 ],
330
+ },
331
+ ("pred_ux" , "pred_uy" , "pred_p" ),
332
+ 10 ,
333
+ 20 ,
334
+ np .linspace (- 2 , 14 , 9 ),
335
+ np .linspace (- 4 , 4 , 5 ),
333
336
)
334
337
335
338
0 commit comments