Skip to content

Commit 2fdee45

Browse files
fix export and infer (#916)
1 parent 3729e14 commit 2fdee45

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

examples/cylinder/2d_unsteady/transformer_physx/train_transformer.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ def export(cfg: DictConfig):
283283

284284
input_spec = [
285285
{
286-
key: InputSpec([None, 16, 128], "float32", name=key)
287-
for key in model.input_keys
286+
"states": InputSpec([1, 255, 3, 64, 128], "float32", name="states"),
287+
"visc": InputSpec([1, 1], "float32", name="visc"),
288288
},
289289
]
290290

@@ -309,6 +309,7 @@ def inference(cfg: DictConfig):
309309

310310
input_dict = {
311311
"states": dataset.data[: cfg.VIS_DATA_NUMS, :-1],
312+
"visc": dataset.visc[: cfg.VIS_DATA_NUMS],
312313
}
313314

314315
output_dict = predictor.predict(input_dict)
@@ -319,17 +320,19 @@ def inference(cfg: DictConfig):
319320
store_key: output_dict[infer_key]
320321
for store_key, infer_key in zip(output_keys, output_dict.keys())
321322
}
322-
323-
input_dict = {
324-
"states": dataset.data[: cfg.VIS_DATA_NUMS, 1:],
325-
}
326-
327-
data_dict = {**input_dict, **output_dict}
328323
for i in range(cfg.VIS_DATA_NUMS):
329-
ppsci.visualize.save_plot_from_3d_dict(
324+
ppsci.visualize.plot.save_plot_from_2d_dict(
330325
f"./cylinder_transformer_pred_{i}",
331-
{key: value[i] for key, value in data_dict.items()},
332-
("states", "pred_states"),
326+
{
327+
"pred_ux": output_dict["pred_states"][i][:, 0],
328+
"pred_uy": output_dict["pred_states"][i][:, 1],
329+
"pred_p": output_dict["pred_states"][i][:, 2],
330+
},
331+
("pred_ux", "pred_uy", "pred_p"),
332+
10,
333+
20,
334+
np.linspace(-2, 14, 9),
335+
np.linspace(-4, 4, 5),
333336
)
334337

335338

0 commit comments

Comments
 (0)