Skip to content

Commit f14e106

Browse files
support saving point cloud to vtp format (#1122)
1 parent 8077c51 commit f14e106

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

ppsci/visualize/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ppsci.visualize.visualizer import VisualizerWeather # isort:skip
2929
from ppsci.visualize.radar import VisualizerRadar # isort:skip
3030
from ppsci.visualize.vtu import save_vtu_from_dict # isort:skip
31+
from ppsci.visualize.vtu import save_vtp_from_dict # isort:skip
3132
from ppsci.visualize.plot import save_plot_from_1d_dict # isort:skip
3233
from ppsci.visualize.plot import save_plot_from_3d_dict # isort:skip
3334
from ppsci.visualize.plot import save_plot_weather_from_dict # isort:skip
@@ -44,6 +45,7 @@
4445
"VisualizerWeather",
4546
"VisualizerRadar",
4647
"save_vtu_from_dict",
48+
"save_vtp_from_dict",
4749
"save_vtu_to_mesh",
4850
"save_plot_from_1d_dict",
4951
"save_plot_from_3d_dict",

ppsci/visualize/vtu.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,92 @@ def save_vtu_from_dict(
147147
_save_vtu_from_array(filename, coord, value, value_keys, num_timestamps)
148148

149149

150+
def save_vtp_from_dict(
151+
filename: str,
152+
data_dict: Dict[str, np.ndarray],
153+
coord_keys: Tuple[str, ...],
154+
value_keys: Tuple[str, ...],
155+
num_timestamps: int = 1,
156+
):
157+
"""Save dict data to '*.vtp' file.
158+
159+
Args:
160+
filename (str): Output filename.
161+
data_dict (Dict[str, np.ndarray]): Data in dict.
162+
coord_keys (Tuple[str, ...]): Tuple of coord key. such as ("x", "y").
163+
value_keys (Tuple[str, ...]): Tuple of value key. such as ("u", "v").
164+
num_timestamps (int, optional): Number of timestamp in data_dict. Defaults to 1.
165+
166+
Examples:
167+
>>> import ppsci
168+
>>> import numpy as np
169+
>>> filename = "path/to/file.vtp"
170+
>>> data_dict = {
171+
... "x": np.array([[1], [2], [3],[4]]),
172+
... "y": np.array([[2], [3], [4],[4]]),
173+
... "z": np.array([[3], [4], [5],[4]]),
174+
... "u": np.array([[4], [5], [6],[4]]),
175+
... "v": np.array([[5], [6], [7],[4]]),
176+
... }
177+
>>> coord_keys = ("x","y","z")
178+
>>> value_keys = ("u","v")
179+
>>> ppsci.visualize.save_vtp_from_dict(filename, data_dict, coord_keys, value_keys) # doctest: +SKIP
180+
"""
181+
import pyvista as pv
182+
183+
if len(coord_keys) not in [3]:
184+
raise ValueError(f"ndim of coord ({len(coord_keys)}) should be 3 in vtp format")
185+
186+
coord = [data_dict[k] for k in coord_keys if k not in ("t", "sdf")]
187+
assert all([c.ndim == 2 for c in coord]), "array of each axis should be [*, 1]"
188+
coord = np.concatenate(coord, axis=1)
189+
190+
if not isinstance(coord, np.ndarray):
191+
raise ValueError(f"type of coord({type(coord)}) should be ndarray.")
192+
if len(coord) % num_timestamps != 0:
193+
raise ValueError(
194+
f"coord length({len(coord)}) should be an integer multiple of "
195+
f"num_timestamps({num_timestamps})"
196+
)
197+
if coord.shape[1] not in [3]:
198+
raise ValueError(f"ndim of coord({coord.shape[1]}) should be 3 in vtp format.")
199+
200+
if len(os.path.dirname(filename)):
201+
os.makedirs(os.path.dirname(filename), exist_ok=True)
202+
203+
npoint = len(coord)
204+
nx = npoint // num_timestamps
205+
if filename.endswith(".vtp"):
206+
filename = filename[:-4]
207+
208+
for t in range(num_timestamps):
209+
coord_ = coord[t * nx : (t + 1) * nx]
210+
point_cloud = pv.PolyData(coord_)
211+
for k in value_keys:
212+
value_ = data_dict[k][t * nx : (t + 1) * nx]
213+
if value_ is not None and not isinstance(value_, np.ndarray):
214+
raise ValueError(f"type of value({type(value_)}) should be ndarray.")
215+
if value_ is not None and len(coord_) != len(value_):
216+
raise ValueError(
217+
f"coord length({len(coord_)}) should be equal to value length({len(value_)})"
218+
)
219+
point_cloud[k] = value_
220+
221+
if num_timestamps > 1:
222+
width = len(str(num_timestamps - 1))
223+
point_cloud.save(f"{filename}_t-{t:0{width}}.vtp")
224+
else:
225+
point_cloud.save(f"{filename}.vtp")
226+
227+
if num_timestamps > 1:
228+
logger.message(
229+
f"Visualization results are saved to: {filename}_t-{0:0{width}}.vtp ~ "
230+
f"{filename}_t-{num_timestamps - 1:0{width}}.vtp"
231+
)
232+
else:
233+
logger.message(f"Visualization result is saved to: {filename}.vtp")
234+
235+
150236
def save_vtu_to_mesh(
151237
filename: str,
152238
data_dict: Dict[str, np.ndarray],

0 commit comments

Comments
 (0)