Skip to content

Commit 01f9987

Browse files
committed
Add function to read CG1 functions from VTKHDF file
1 parent 6ce5d1f commit 01f9987

File tree

4 files changed

+218
-2
lines changed

4 files changed

+218
-2
lines changed

cpp/dolfinx/io/VTKHDF.h

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,4 +642,138 @@ std::vector<U> read_data(std::string point_or_cell, std::string filename,
642642

643643
return values;
644644
}
645+
646+
/// @brief Read a CG1 function from a VTKHDF format file.
647+
///
648+
/// @tparam U Scalar type of mesh
649+
/// @param filename Name of the file to read from.
650+
/// @param mesh Mesh previously read from the same file.
651+
/// @param timestep The time step to read for time-dependent data.
652+
/// @return The function read from file.
653+
///
654+
/// @note This only supports meshes with a single cell type as of now.
655+
template <std::floating_point U>
656+
fem::Function<U> read_CG1_function(std::string filename,
657+
std::shared_ptr<mesh::Mesh<U>> mesh,
658+
std::int32_t timestep = 0)
659+
{
660+
auto element = basix::create_element<U>(
661+
basix::element::family::P,
662+
mesh::cell_type_to_basix_type(mesh->topology()->cell_type()), 1,
663+
basix::element::lagrange_variant::unset,
664+
basix::element::dpc_variant::unset, false);
665+
666+
hid_t h5file = hdf5::open_file(mesh->comm(), filename, "r", true);
667+
std::string dataset_name = "/VTKHDF/PointData/u";
668+
669+
std::vector<std::int64_t> shape
670+
= hdf5::get_dataset_shape(h5file, dataset_name);
671+
hdf5::close_file(h5file);
672+
673+
const std::size_t block_size = shape[1];
674+
675+
auto P1
676+
= std::make_shared<fem::FunctionSpace<U>>(fem::create_functionspace<U>(
677+
mesh, std::make_shared<fem::FiniteElement<U>>(
678+
element, std::vector<std::size_t>{block_size})));
679+
680+
fem::Function<U> u_in(P1);
681+
682+
std::shared_ptr<const common::IndexMap> index_map(
683+
mesh->geometry().index_map());
684+
685+
std::int64_t range0 = index_map->local_range()[0];
686+
687+
int npoints = index_map->size_local();
688+
689+
std::array<std::int64_t, 2> range{range0, range0 + npoints};
690+
691+
const auto values
692+
= io::VTKHDF::read_data("Point", filename, *mesh, range, timestep);
693+
694+
// Parallel distribution. For vector functions we distribute each component
695+
// separately.
696+
std::vector<std::vector<double>> values_s(block_size);
697+
for (auto& v : values_s)
698+
{
699+
v.reserve(values.size() / block_size);
700+
}
701+
for (std::size_t i = 0; i < values.size() / block_size; ++i)
702+
{
703+
for (int j = 0; j < block_size; ++j)
704+
{
705+
values_s[j].push_back(values[block_size * i + j]);
706+
}
707+
}
708+
709+
std::vector<std::int64_t> entities(range[1] - range[0]);
710+
std::iota(entities.begin(), entities.end(), range[0]);
711+
712+
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
713+
const std::int64_t,
714+
MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
715+
entities_span(entities.data(),
716+
std::array<std::size_t, 2>{entities.size(), 1});
717+
718+
std::vector<std::pair<std::vector<std::int32_t>, std::vector<double>>>
719+
entities_values;
720+
721+
for (std::size_t i = 0; i < values_s.size(); ++i)
722+
{
723+
entities_values.push_back(dolfinx::io::distribute_entity_data<double>(
724+
*mesh->topology(), mesh->geometry().input_global_indices(),
725+
mesh->geometry().index_map()->size_global(),
726+
mesh->geometry().cmap().create_dof_layout(), mesh->geometry().dofmap(),
727+
mesh::cell_dim(mesh::CellType::point), entities_span, values_s[i]));
728+
}
729+
730+
auto num_vertices_per_cell
731+
= dolfinx::mesh::num_cell_vertices(mesh->topology()->cell_type());
732+
std::vector<std::int32_t> local_vertex_map(num_vertices_per_cell);
733+
734+
for (int i = 0; i < num_vertices_per_cell; ++i)
735+
{
736+
const auto v_to_d
737+
= u_in.function_space()->dofmap()->element_dof_layout().entity_dofs(0,
738+
i);
739+
assert(v_to_d.size() == 1);
740+
local_vertex_map[i] = v_to_d.front();
741+
}
742+
743+
const auto tdim = mesh->topology()->dim();
744+
const auto c_to_v = mesh->topology()->connectivity(tdim, 0);
745+
std::vector<std::int32_t> vertex_to_dofmap(
746+
mesh->topology()->index_map(0)->size_local()
747+
+ mesh->topology()->index_map(0)->num_ghosts());
748+
749+
for (int i = 0; i < mesh->topology()->index_map(tdim)->size_local(); ++i)
750+
{
751+
const auto local_vertices = c_to_v->links(i);
752+
const auto local_dofs = u_in.function_space()->dofmap()->cell_dofs(i);
753+
for (int j = 0; j < num_vertices_per_cell; ++j)
754+
{
755+
vertex_to_dofmap[local_vertices[j]] = local_dofs[local_vertex_map[j]];
756+
}
757+
}
758+
759+
/*
760+
* After the data is read and distributed, we need to place the
761+
* retrieved values in the correct position in the function's array,
762+
* reading values and positions from `entities_values`.
763+
*/
764+
for (std::size_t i = 0; i < entities_values[0].first.size(); ++i)
765+
{
766+
for (std::size_t j = 0; j < block_size; ++j)
767+
{
768+
u_in.x()
769+
->array()[block_size * vertex_to_dofmap[entities_values[0].first[i]]
770+
+ j]
771+
= entities_values[j].second[i];
772+
}
773+
}
774+
775+
u_in.x()->scatter_fwd();
776+
777+
return u_in;
778+
}
645779
} // namespace dolfinx::io::VTKHDF

python/dolfinx/io/vtkhdf.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@
1212
import numpy.typing as npt
1313

1414
import basix
15+
import dolfinx
1516
import ufl
1617
from dolfinx.cpp.io import (
18+
read_vtkhdf_cg1_function,
1719
read_vtkhdf_mesh_float32,
1820
read_vtkhdf_mesh_float64,
1921
write_vtkhdf_cg1_function,
2022
write_vtkhdf_data,
2123
write_vtkhdf_function,
2224
write_vtkhdf_mesh,
2325
)
24-
from dolfinx.fem import Function
26+
from dolfinx.fem import Function, FunctionSpace
2527
from dolfinx.mesh import Mesh
2628

2729
__all__ = [
30+
"read_cg1_function",
2831
"read_mesh",
2932
"write_cell_data",
3033
"write_cg1_function",
@@ -122,3 +125,32 @@ def write_function(filename: str | Path, mesh: Mesh, u: Function, time: float):
122125
time: Timestamp.
123126
"""
124127
write_vtkhdf_function(filename, mesh._cpp_object, u._cpp_object, time)
128+
129+
def read_cg1_function(filename: str | Path, mesh: Mesh, timestep: int = 0) -> Function:
130+
"""Read a CG1 function form file.
131+
132+
Args:
133+
filename: File to read.
134+
mesh: Mesh.
135+
timestep: Time step to read
136+
137+
Returns:
138+
The read function.
139+
"""
140+
match type(mesh._cpp_object):
141+
case dolfinx.cpp.mesh.Mesh_float64:
142+
dtype = np.float64
143+
case dolfinx.cpp.mesh.Mesh_float32:
144+
dtype = np.float32
145+
u_cpp = read_vtkhdf_cg1_function(filename, mesh._cpp_object, timestep)
146+
elem = basix.ufl.element(
147+
"Lagrange",
148+
mesh.basix_cell(),
149+
1,
150+
shape=tuple(u_cpp.function_space.element.value_shape),
151+
dtype=dtype,
152+
)
153+
V = FunctionSpace(mesh, elem, u_cpp.function_space)
154+
u_ret = Function(V, dtype=dtype)
155+
u_ret.interpolate(u_cpp)
156+
return u_ret

python/dolfinx/wrappers/io.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ void io(nb::module_& m)
227227
m.def("write_vtkhdf_data", &dolfinx::io::VTKHDF::write_data<float>);
228228
m.def("write_vtkhdf_function", &dolfinx::io::VTKHDF::write_function<float>);
229229
m.def("write_vtkhdf_function", &dolfinx::io::VTKHDF::write_function<double>);
230+
m.def("read_vtkhdf_cg1_function", &dolfinx::io::VTKHDF::read_CG1_function<float>);
231+
m.def("read_vtkhdf_cg1_function", &dolfinx::io::VTKHDF::read_CG1_function<double>);
230232
m.def("read_vtkhdf_mesh_float64",
231233
[](MPICommWrapper comm, const std::string& filename, std::size_t gdim,
232234
std::optional<std::int32_t> max_facet_to_cell_links)

python/test/unit/io/test_vtkhdf.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,17 @@
99
import numpy as np
1010
import pytest
1111

12+
import basix
1213
import dolfinx
1314
import ufl
14-
from dolfinx.io.vtkhdf import read_mesh, write_cell_data, write_mesh, write_point_data
15+
from dolfinx.io.vtkhdf import (
16+
read_cg1_function,
17+
read_mesh,
18+
write_cell_data,
19+
write_cg1_function,
20+
write_mesh,
21+
write_point_data,
22+
)
1523
from dolfinx.mesh import CellType, Mesh, create_unit_cube, create_unit_square
1624

1725

@@ -206,3 +214,43 @@ def test_write_mixed_topology_data(mixed_topology_mesh):
206214
b = sum([im.size_local for im in mesh.topology.index_maps(mesh.topology.dim)])
207215
cell_data = np.arange(b, dtype=np.float64)
208216
write_cell_data(filename, mesh, cell_data, 0.0)
217+
218+
219+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
220+
@pytest.mark.parametrize("width", [1, 3])
221+
def test_write_read_cg1_function(dtype, width):
222+
mesh = create_unit_cube(MPI.COMM_WORLD, 5, 5, 5, dtype=dtype)
223+
elem = basix.ufl.element("Lagrange", mesh.basix_cell(), 1, shape=(width,), dtype=dtype)
224+
V = dolfinx.fem.functionspace(mesh, elem)
225+
f = dolfinx.fem.Function(V, dtype=dtype)
226+
f.interpolate(lambda x: np.vstack([np.sin(x[0]) ** 2 + np.cos(x[1] ** 2 * x[2])] * width))
227+
228+
filename = f"cg1_fun_{width}.vtkhdf"
229+
write_mesh(filename, mesh)
230+
write_cg1_function(filename, mesh, f, 0)
231+
232+
mesh2 = read_mesh(mesh.comm, filename, dtype=dtype)
233+
f_2 = read_cg1_function(filename, mesh2)
234+
235+
filename2 = f"cg1_fun2_{width}.vtkhdf"
236+
write_mesh(filename2, mesh2)
237+
write_cg1_function(filename2, mesh2, f_2, 0)
238+
239+
cell_map = mesh.topology.index_map(mesh.topology.dim)
240+
num_cells_on_proc = cell_map.size_local + cell_map.num_ghosts
241+
cells = np.arange(num_cells_on_proc, dtype=np.int32)
242+
padding = 1e-14 if dtype == np.float64 else 1e-5
243+
interpolation_data = dolfinx.fem.create_interpolation_data(
244+
V, f_2.function_space, cells, padding=padding
245+
)
246+
247+
f_i = dolfinx.fem.Function(V, dtype=dtype)
248+
f_i.interpolate_nonmatching(f_2, cells, interpolation_data=interpolation_data)
249+
f_i.x.scatter_forward()
250+
251+
filename_i = f"cg1_funi_{width}.vtkhdf"
252+
write_mesh(filename_i, mesh)
253+
write_cg1_function(filename_i, mesh, f_i, 0)
254+
255+
tol = 1e-14 if dtype == np.float64 else 1e-6
256+
np.testing.assert_allclose(f.x.array, f_i.x.array, rtol=tol)

0 commit comments

Comments
 (0)