Skip to content

Commit e29fd66

Browse files
fix visualization and related 2 utilities (#1186)
1 parent 2d636cb commit e29fd66

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

ppsci/solver/visu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ def visualize_func(solver: "solver.Solver", epoch_id: Optional[int]):
7979
for key, batch_output in batch_output_dict.items():
8080
all_output[key].append(batch_output.detach().astype("float32"))
8181

82-
# concatenate all data
82+
# concatenate all data and convert to numpy array
8383
for key in all_input:
84-
all_input[key] = paddle.concat(all_input[key])
84+
all_input[key] = paddle.concat(all_input[key]).numpy()
8585
for key in all_output:
86-
all_output[key] = paddle.concat(all_output[key])
86+
all_output[key] = paddle.concat(all_output[key]).numpy()
8787

8888
# save visualization
8989
with misc.RankZeroOnly(solver.rank) as is_master:

ppsci/visualize/vtu.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,13 @@ def save_vtu_from_dict(
133133
>>> value_keys = ("u","v")
134134
>>> ppsci.visualize.save_vtu_from_dict(filename, data_dict, coord_keys, value_keys) # doctest: +SKIP
135135
"""
136-
if len(coord_keys) not in [2, 3, 4]:
137-
raise ValueError(f"ndim of coord ({len(coord_keys)}) should be 2, 3 or 4")
136+
spatial_coord_keys = [key for key in coord_keys if key not in ("t", "sdf")]
137+
if len(spatial_coord_keys) not in [1, 2, 3]:
138+
raise ValueError(
139+
f"ndim of spatial coord ({len(spatial_coord_keys)}) should be 1, 2, or 3"
140+
)
138141

139-
coord = [data_dict[k] for k in coord_keys if k not in ("t", "sdf")]
142+
coord = [data_dict[k] for k in spatial_coord_keys]
140143
value = [data_dict[k] for k in value_keys] if value_keys else None
141144

142145
coord = np.concatenate(coord, axis=1)
@@ -180,10 +183,13 @@ def save_vtp_from_dict(
180183
"""
181184
import pyvista as pv
182185

183-
if len(coord_keys) not in [3]:
184-
raise ValueError(f"ndim of coord ({len(coord_keys)}) should be 3 in vtp format")
186+
spatial_coord_keys = [key for key in coord_keys if key not in ("t", "sdf")]
187+
if len(spatial_coord_keys) not in [3]:
188+
raise ValueError(
189+
f"ndim of spatial coord ({len(spatial_coord_keys)}) should be 3 in vtp format"
190+
)
185191

186-
coord = [data_dict[k] for k in coord_keys if k not in ("t", "sdf")]
192+
coord = [data_dict[k] for k in spatial_coord_keys]
187193
assert all([c.ndim == 2 for c in coord]), "array of each axis should be [*, 1]"
188194
coord = np.concatenate(coord, axis=1)
189195

0 commit comments

Comments
 (0)