Skip to content

Commit f0e7771

Browse files
committed
TL: added vtk support for 3D VTR files
1 parent bf7e7d2 commit f0e7771

File tree

4 files changed

+131
-2
lines changed

4 files changed

+131
-2
lines changed

etc/environment-base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies:
1010
- sympy>=1.0
1111
- numba>=0.35
1212
- dill>=0.2.6
13+
- vtk
1314
- pip
1415
- pip:
1516
- qmat>=0.1.8

pySDC/helpers/fieldsIO.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
To use MPI collective writing, you need to call first the class methods :class:`Rectilinear.initMPI` (cf their docstring).
4444
Also, `Rectilinear.setHeader` **must be given the global grids coordinates**, wether the code is run in parallel or not.
4545
46-
Also : this feature is only available with Python 3.11 or higher ...
46+
> ⚠️ Also : this module can only be imported with **Python 3.11 or higher** !
4747
"""
4848
import os
4949
import numpy as np
@@ -640,7 +640,9 @@ def readField(self, idx):
640640
return t, field
641641

642642

643-
# Utility function used for testing
643+
# -----------------------------------------------------------------------------------------------
644+
# Utility functions used for testing
645+
# -----------------------------------------------------------------------------------------------
644646
def initGrid(nVar, gridSizes):
645647
dim = len(gridSizes)
646648
coords = [np.linspace(0, 1, num=n, endpoint=False) for n in gridSizes]

pySDC/helpers/vtk.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Helper functions to write and read fields from VTK files (to be used with Paraview or PyVista)
5+
"""
6+
import os
7+
import vtk
8+
from vtkmodules.util import numpy_support
9+
import numpy as np
10+
11+
12+
def writeToVTR(fileName: str, data, coords, varNames):
13+
"""
14+
Write a data array containing variables from a 3D rectilinear grid into a VTR file.
15+
16+
Parameters
17+
----------
18+
fileName : str
19+
Name of the VTR file, can be with or without the .vtr extension.
20+
data : np.4darray
21+
Array containing all the variables with [nVar, nX, nY, nZ] shape.
22+
coords : list[np.1darray]
23+
Coordinates in each direction.
24+
varNames : list[str]
25+
Variable names.
26+
27+
Returns
28+
-------
29+
fileName : str
30+
Name of the VTR file.
31+
"""
32+
data = np.asarray(data)
33+
nVar, *gridSizes = data.shape
34+
35+
assert len(gridSizes) == 3, "function can be used only for 3D grid data"
36+
assert nVar == len(varNames), "varNames must have as many variable as data"
37+
assert [len(np.ravel(coord)) for coord in coords] == gridSizes, "coordinate size incompatible with data shape"
38+
39+
nX, nY, nZ = gridSizes
40+
vtr = vtk.vtkRectilinearGrid()
41+
vtr.SetDimensions(nX, nY, nZ)
42+
43+
vect = lambda x: numpy_support.numpy_to_vtk(num_array=x, deep=True, array_type=vtk.VTK_FLOAT)
44+
x, y, z = coords
45+
vtr.SetXCoordinates(vect(x))
46+
vtr.SetYCoordinates(vect(y))
47+
vtr.SetZCoordinates(vect(z))
48+
49+
field = lambda u: numpy_support.numpy_to_vtk(num_array=u.ravel(order='F'), deep=True, array_type=vtk.VTK_FLOAT)
50+
pointData = vtr.GetPointData()
51+
for name, u in zip(varNames, data):
52+
uVTK = field(u)
53+
uVTK.SetName(name)
54+
pointData.AddArray(uVTK)
55+
56+
writer = vtk.vtkXMLRectilinearGridWriter()
57+
if not fileName.endswith(".vtr"):
58+
fileName += ".vtr"
59+
writer.SetFileName(fileName)
60+
writer.SetInputData(vtr)
61+
writer.Write()
62+
63+
return fileName
64+
65+
66+
def readFromVTR(fileName: str):
67+
"""
68+
Read a VTR file into a numpy 4darray
69+
70+
Parameters
71+
----------
72+
fileName : str
73+
Name of the VTR file, can be with or without the .vtr extension.
74+
75+
Returns
76+
-------
77+
data : np.4darray
78+
Array containing all the variables with [nVar, nX, nY, nZ] shape.
79+
coords : list[np.1darray]
80+
Coordinates in each direction.
81+
varNames : list[str]
82+
Variable names.
83+
"""
84+
if not fileName.endswith(".vtr"):
85+
fileName += ".vtr"
86+
assert os.path.isfile(fileName), f"{fileName} does not exist"
87+
88+
reader = vtk.vtkXMLRectilinearGridReader()
89+
reader.SetFileName(fileName)
90+
reader.Update()
91+
92+
vtr = reader.GetOutput()
93+
dims = vtr.GetDimensions()
94+
assert len(dims) == 3, "can only read 3D data"
95+
96+
vect = lambda x: numpy_support.vtk_to_numpy(x)
97+
coords = [vect(vtr.GetXCoordinates()), vect(vtr.GetYCoordinates()), vect(vtr.GetZCoordinates())]
98+
pointData = vtr.GetPointData()
99+
varNames = [pointData.GetArrayName(i) for i in range(pointData.GetNumberOfArrays())]
100+
data = [numpy_support.vtk_to_numpy(pointData.GetArray(name)).reshape(dims, order="F") for name in varNames]
101+
data = np.array(data)
102+
103+
return data, coords, varNames
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import numpy as np
3+
4+
5+
@pytest.mark.parametrize("nZ", [1, 5, 16])
6+
@pytest.mark.parametrize("nY", [1, 5, 16])
7+
@pytest.mark.parametrize("nX", [1, 5, 16])
8+
@pytest.mark.parametrize("nVar", [1, 2, 3])
9+
def testVTR(nVar, nX, nY, nZ):
10+
from pySDC.helpers.vtk import writeToVTR, readFromVTR
11+
12+
data1 = np.random.rand(nVar, nX, nY, nZ)
13+
coords1 = [np.sort(np.random.rand(n)) for n in [nX, nY, nZ]]
14+
varNames1 = [f"var{i}" for i in range(nVar)]
15+
16+
data2, coords2, varNames2 = readFromVTR(writeToVTR("testVTR", data1, coords1, varNames1))
17+
18+
for i, (x1, x2) in enumerate(zip(coords1, coords2)):
19+
print(x1, x2)
20+
assert np.allclose(x1, x2), f"coordinate mismatch in dir. {i}"
21+
assert varNames1 == varNames2, f"varNames mismatch"
22+
assert data1.shape == data2.shape, f"data shape mismatch"
23+
assert np.allclose(data1, data2), f"data values mismatch"

0 commit comments

Comments
 (0)