Skip to content

Commit a03b948

Browse files
committed
Add function to read CG1 functions from VTKHDF file
1 parent 46a530e commit a03b948

File tree

4 files changed

+219
-2
lines changed

4 files changed

+219
-2
lines changed

cpp/dolfinx/io/VTKHDF.h

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,4 +643,138 @@ std::vector<U> read_data(std::string point_or_cell, std::string filename,
643643

644644
return values;
645645
}
646+
647+
/// @brief Read a CG1 function from a VTKHDF format file.
648+
///
649+
/// @tparam U Scalar type of mesh
650+
/// @param filename Name of the file to read from.
651+
/// @param mesh Mesh previously read from the same file.
652+
/// @param timestep The time step to read for time-dependent data.
653+
/// @return The function read from file.
654+
///
655+
/// @note This only supports meshes with a single cell type as of now.
656+
template <std::floating_point U>
657+
fem::Function<U> read_CG1_function(std::string filename,
658+
std::shared_ptr<mesh::Mesh<U>> mesh,
659+
std::int32_t timestep = 0)
660+
{
661+
auto element = basix::create_element<U>(
662+
basix::element::family::P,
663+
mesh::cell_type_to_basix_type(mesh->topology()->cell_type()), 1,
664+
basix::element::lagrange_variant::unset,
665+
basix::element::dpc_variant::unset, false);
666+
667+
hid_t h5file = hdf5::open_file(mesh->comm(), filename, "r", true);
668+
std::string dataset_name = "/VTKHDF/PointData/u";
669+
670+
std::vector<std::int64_t> shape
671+
= hdf5::get_dataset_shape(h5file, dataset_name);
672+
hdf5::close_file(h5file);
673+
674+
const std::size_t block_size = shape[1];
675+
676+
auto P1
677+
= std::make_shared<fem::FunctionSpace<U>>(fem::create_functionspace<U>(
678+
mesh, std::make_shared<fem::FiniteElement<U>>(
679+
element, std::vector<std::size_t>{block_size})));
680+
681+
fem::Function<U> u_in(P1);
682+
683+
std::shared_ptr<const common::IndexMap> index_map(
684+
mesh->geometry().index_map());
685+
686+
std::int64_t range0 = index_map->local_range()[0];
687+
688+
int npoints = index_map->size_local();
689+
690+
std::array<std::int64_t, 2> range{range0, range0 + npoints};
691+
692+
const auto values
693+
= io::VTKHDF::read_data("Point", filename, *mesh, range, timestep);
694+
695+
// Parallel distribution. For vector functions we distribute each component
696+
// separately.
697+
std::vector<std::vector<double>> values_s(block_size);
698+
for (auto& v : values_s)
699+
{
700+
v.reserve(values.size() / block_size);
701+
}
702+
for (std::size_t i = 0; i < values.size() / block_size; ++i)
703+
{
704+
for (int j = 0; j < block_size; ++j)
705+
{
706+
values_s[j].push_back(values[block_size * i + j]);
707+
}
708+
}
709+
710+
std::vector<std::int64_t> entities(range[1] - range[0]);
711+
std::iota(entities.begin(), entities.end(), range[0]);
712+
713+
MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
714+
const std::int64_t,
715+
MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
716+
entities_span(entities.data(),
717+
std::array<std::size_t, 2>{entities.size(), 1});
718+
719+
std::vector<std::pair<std::vector<std::int32_t>, std::vector<double>>>
720+
entities_values;
721+
722+
for (std::size_t i = 0; i < values_s.size(); ++i)
723+
{
724+
entities_values.push_back(dolfinx::io::distribute_entity_data<double>(
725+
*mesh->topology(), mesh->geometry().input_global_indices(),
726+
mesh->geometry().index_map()->size_global(),
727+
mesh->geometry().cmap().create_dof_layout(), mesh->geometry().dofmap(),
728+
mesh::cell_dim(mesh::CellType::point), entities_span, values_s[i]));
729+
}
730+
731+
auto num_vertices_per_cell
732+
= dolfinx::mesh::num_cell_vertices(mesh->topology()->cell_type());
733+
std::vector<std::int32_t> local_vertex_map(num_vertices_per_cell);
734+
735+
for (int i = 0; i < num_vertices_per_cell; ++i)
736+
{
737+
const auto v_to_d
738+
= u_in.function_space()->dofmap()->element_dof_layout().entity_dofs(0,
739+
i);
740+
assert(v_to_d.size() == 1);
741+
local_vertex_map[i] = v_to_d.front();
742+
}
743+
744+
const auto tdim = mesh->topology()->dim();
745+
const auto c_to_v = mesh->topology()->connectivity(tdim, 0);
746+
std::vector<std::int32_t> vertex_to_dofmap(
747+
mesh->topology()->index_map(0)->size_local()
748+
+ mesh->topology()->index_map(0)->num_ghosts());
749+
750+
for (int i = 0; i < mesh->topology()->index_map(tdim)->size_local(); ++i)
751+
{
752+
const auto local_vertices = c_to_v->links(i);
753+
const auto local_dofs = u_in.function_space()->dofmap()->cell_dofs(i);
754+
for (int j = 0; j < num_vertices_per_cell; ++j)
755+
{
756+
vertex_to_dofmap[local_vertices[j]] = local_dofs[local_vertex_map[j]];
757+
}
758+
}
759+
760+
/*
761+
* After the data is read and distributed, we need to place the
762+
* retrieved values in the correct position in the function's array,
763+
* reading values and positions from `entities_values`.
764+
*/
765+
for (std::size_t i = 0; i < entities_values[0].first.size(); ++i)
766+
{
767+
for (std::size_t j = 0; j < block_size; ++j)
768+
{
769+
u_in.x()
770+
->array()[block_size * vertex_to_dofmap[entities_values[0].first[i]]
771+
+ j]
772+
= entities_values[j].second[i];
773+
}
774+
}
775+
776+
u_in.x()->scatter_fwd();
777+
778+
return u_in;
779+
}
646780
} // namespace dolfinx::io::VTKHDF

python/dolfinx/io/vtkhdf.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,22 @@
1212
import numpy.typing as npt
1313

1414
import basix
15+
import dolfinx
1516
import ufl
1617
from dolfinx.cpp.io import (
18+
read_vtkhdf_cg1_function,
1719
read_vtkhdf_mesh_float32,
1820
read_vtkhdf_mesh_float64,
1921
write_vtkhdf_cg1_function,
2022
write_vtkhdf_data,
2123
write_vtkhdf_dg0_function,
2224
write_vtkhdf_mesh,
2325
)
24-
from dolfinx.fem import Function
26+
from dolfinx.fem import Function, FunctionSpace
2527
from dolfinx.mesh import Mesh
2628

2729
__all__ = [
30+
"read_cg1_function",
2831
"read_mesh",
2932
"write_cell_data",
3033
"write_cg1_function",
@@ -134,3 +137,33 @@ def write_dg0_function(filename: str | Path, mesh: Mesh, u: Function, time: floa
134137
time: Timestamp.
135138
"""
136139
write_vtkhdf_dg0_function(filename, mesh._cpp_object, u._cpp_object, time)
140+
141+
142+
def read_cg1_function(filename: str | Path, mesh: Mesh, timestep: int = 0) -> Function:
143+
"""Read a CG1 function form file.
144+
145+
Args:
146+
filename: File to read.
147+
mesh: Mesh.
148+
timestep: Time step to read
149+
150+
Returns:
151+
The read function.
152+
"""
153+
match type(mesh._cpp_object):
154+
case dolfinx.cpp.mesh.Mesh_float64:
155+
dtype = np.float64
156+
case dolfinx.cpp.mesh.Mesh_float32:
157+
dtype = np.float32
158+
u_cpp = read_vtkhdf_cg1_function(filename, mesh._cpp_object, timestep)
159+
elem = basix.ufl.element(
160+
"Lagrange",
161+
mesh.basix_cell(),
162+
1,
163+
shape=tuple(u_cpp.function_space.element.value_shape),
164+
dtype=dtype,
165+
)
166+
V = FunctionSpace(mesh, elem, u_cpp.function_space)
167+
u_ret = Function(V, dtype=dtype)
168+
u_ret.interpolate(u_cpp)
169+
return u_ret

python/dolfinx/wrappers/io.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ void io(nb::module_& m)
229229
m.def("write_vtkhdf_cg1_function", &dolfinx::io::VTKHDF::write_CG1_function<double>);
230230
m.def("write_vtkhdf_dg0_function", &dolfinx::io::VTKHDF::write_DG0_function<float>);
231231
m.def("write_vtkhdf_dg0_function", &dolfinx::io::VTKHDF::write_DG0_function<double>);
232+
m.def("read_vtkhdf_cg1_function", &dolfinx::io::VTKHDF::read_CG1_function<float>);
233+
m.def("read_vtkhdf_cg1_function", &dolfinx::io::VTKHDF::read_CG1_function<double>);
232234
m.def("read_vtkhdf_mesh_float64",
233235
[](MPICommWrapper comm, const std::string& filename, std::size_t gdim,
234236
std::optional<std::int32_t> max_facet_to_cell_links)

python/test/unit/io/test_vtkhdf.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,17 @@
99
import numpy as np
1010
import pytest
1111

12+
import basix
1213
import dolfinx
1314
import ufl
14-
from dolfinx.io.vtkhdf import read_mesh, write_cell_data, write_mesh, write_point_data
15+
from dolfinx.io.vtkhdf import (
16+
read_cg1_function,
17+
read_mesh,
18+
write_cell_data,
19+
write_cg1_function,
20+
write_mesh,
21+
write_point_data,
22+
)
1523
from dolfinx.mesh import CellType, Mesh, create_unit_cube, create_unit_square
1624

1725

@@ -206,3 +214,43 @@ def test_write_mixed_topology_data(mixed_topology_mesh):
206214
b = sum([im.size_local for im in mesh.topology.index_maps(mesh.topology.dim)])
207215
cell_data = np.arange(b, dtype=np.float64)
208216
write_cell_data(filename, mesh, cell_data, 0.0)
217+
218+
219+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
220+
@pytest.mark.parametrize("width", [1, 3])
221+
def test_write_read_cg1_function(dtype, width):
222+
mesh = create_unit_cube(MPI.COMM_WORLD, 5, 5, 5, dtype=dtype)
223+
elem = basix.ufl.element("Lagrange", mesh.basix_cell(), 1, shape=(width,), dtype=dtype)
224+
V = dolfinx.fem.functionspace(mesh, elem)
225+
f = dolfinx.fem.Function(V, dtype=dtype)
226+
f.interpolate(lambda x: np.vstack([np.sin(x[0]) ** 2 + np.cos(x[1] ** 2 * x[2])] * width))
227+
228+
filename = f"cg1_fun_{width}.vtkhdf"
229+
write_mesh(filename, mesh)
230+
write_cg1_function(filename, mesh, f, 0)
231+
232+
mesh2 = read_mesh(mesh.comm, filename, dtype=dtype)
233+
f_2 = read_cg1_function(filename, mesh2)
234+
235+
filename2 = f"cg1_fun2_{width}.vtkhdf"
236+
write_mesh(filename2, mesh2)
237+
write_cg1_function(filename2, mesh2, f_2, 0)
238+
239+
cell_map = mesh.topology.index_map(mesh.topology.dim)
240+
num_cells_on_proc = cell_map.size_local + cell_map.num_ghosts
241+
cells = np.arange(num_cells_on_proc, dtype=np.int32)
242+
padding = 1e-14 if dtype == np.float64 else 1e-5
243+
interpolation_data = dolfinx.fem.create_interpolation_data(
244+
V, f_2.function_space, cells, padding=padding
245+
)
246+
247+
f_i = dolfinx.fem.Function(V, dtype=dtype)
248+
f_i.interpolate_nonmatching(f_2, cells, interpolation_data=interpolation_data)
249+
f_i.x.scatter_forward()
250+
251+
filename_i = f"cg1_funi_{width}.vtkhdf"
252+
write_mesh(filename_i, mesh)
253+
write_cg1_function(filename_i, mesh, f_i, 0)
254+
255+
tol = 1e-14 if dtype == np.float64 else 1e-6
256+
np.testing.assert_allclose(f.x.array, f_i.x.array, rtol=tol)

0 commit comments

Comments
 (0)