Skip to content

Commit fa84113

Browse files
committed
Pass custom_data through assemblers to kernel functions.
1 parent f1daede commit fa84113

File tree

6 files changed

+419
-48
lines changed

6 files changed

+419
-48
lines changed

cpp/dolfinx/fem/Form.h

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ struct integral_data
5555
/// @param[in] entities Indices of entities to integrate over.
5656
/// @param[in] coeffs Indices of the coefficients that are present
5757
/// (active) in `kernel`.
58+
/// @param[in] custom_data Optional custom user data pointer passed to
59+
/// the kernel function.
5860
template <typename K, typename V, typename W>
5961
requires std::is_convertible_v<
6062
std::remove_cvref_t<K>,
@@ -64,9 +66,10 @@ struct integral_data
6466
std::vector<std::int32_t>>
6567
and std::is_convertible_v<std::remove_cvref_t<W>,
6668
std::vector<int>>
67-
integral_data(K&& kernel, V&& entities, W&& coeffs)
69+
integral_data(K&& kernel, V&& entities, W&& coeffs,
70+
void* custom_data = nullptr)
6871
: kernel(std::forward<K>(kernel)), entities(std::forward<V>(entities)),
69-
coeffs(std::forward<W>(coeffs))
72+
coeffs(std::forward<W>(coeffs)), custom_data(custom_data)
7073
{
7174
}
7275

@@ -82,6 +85,11 @@ struct integral_data
8285
/// @brief Indices of coefficients (from the form) that are in this
8386
/// integral.
8487
std::vector<int> coeffs;
88+
89+
/// @brief Custom user data pointer passed to the kernel function.
90+
/// This can be used to pass runtime-computed data (e.g., per-cell
91+
/// quadrature rules, material properties) to the kernel.
92+
void* custom_data = nullptr;
8593
};
8694

8795
/// @brief A representation of finite element variational forms.
@@ -391,6 +399,39 @@ class Form
391399
return it->second.kernel;
392400
}
393401

402+
/// @brief Get the custom data pointer for an integral.
403+
///
404+
/// The custom data pointer is passed to the kernel function during
405+
/// assembly. This can be used to pass runtime-computed data to
406+
/// kernels (e.g., per-cell quadrature rules, material properties).
407+
///
408+
/// @param[in] type Integral type.
409+
/// @param[in] id Integral subdomain ID.
410+
/// @param[in] kernel_idx Index of the kernel (we may have multiple
411+
/// kernels for a given ID in mixed-topology meshes).
412+
/// @return Custom data pointer for the integral, or nullptr if not set.
413+
void* custom_data(IntegralType type, int id, int kernel_idx) const
414+
{
415+
auto it = _integrals.find({type, id, kernel_idx});
416+
if (it == _integrals.end())
417+
throw std::runtime_error("Requested integral not found.");
418+
return it->second.custom_data;
419+
}
420+
421+
/// @brief Set the custom data pointer for an integral.
422+
///
423+
/// @param[in] type Integral type.
424+
/// @param[in] id Integral subdomain ID.
425+
/// @param[in] kernel_idx Index of the kernel.
426+
/// @param[in] data Custom data pointer to set.
427+
void set_custom_data(IntegralType type, int id, int kernel_idx, void* data)
428+
{
429+
auto it = _integrals.find({type, id, kernel_idx});
430+
if (it == _integrals.end())
431+
throw std::runtime_error("Requested integral not found.");
432+
it->second.custom_data = data;
433+
}
434+
394435
/// @brief Get types of integrals in the form.
395436
/// @return Integrals types.
396437
std::set<IntegralType> integral_types() const

cpp/dolfinx/fem/assemble_matrix_impl.h

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ using mdspan2_t = md::mdspan<const std::int32_t, md::dextents<std::size_t, 2>>;
6060
/// function mesh.
6161
/// @param cell_info1 Cell permutation information for the trial
6262
/// function mesh.
63+
/// @param custom_data Custom user data pointer passed to the kernel.
6364
template <dolfinx::scalar T>
6465
void assemble_cells_matrix(
6566
la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
@@ -74,7 +75,7 @@ void assemble_cells_matrix(
7475
std::span<const std::int8_t> bc1, FEkernel<T> auto kernel,
7576
md::mdspan<const T, md::dextents<std::size_t, 2>> coeffs,
7677
std::span<const T> constants, std::span<const std::uint32_t> cell_info0,
77-
std::span<const std::uint32_t> cell_info1)
78+
std::span<const std::uint32_t> cell_info1, void* custom_data = nullptr)
7879
{
7980
if (cells.empty())
8081
return;
@@ -109,7 +110,7 @@ void assemble_cells_matrix(
109110
// Tabulate tensor
110111
std::ranges::fill(Ae, 0);
111112
kernel(Ae.data(), &coeffs(c, 0), constants.data(), cdofs.data(), nullptr,
112-
nullptr, nullptr);
113+
nullptr, custom_data);
113114

114115
// Compute A = P_0 \tilde{A} P_1^T (dof transformation)
115116
P0(Ae, cell_info0, cell0, ndim1); // B = P0 \tilde{A}
@@ -198,6 +199,7 @@ void assemble_cells_matrix(
198199
/// function mesh.
199200
/// @param[in] perms Entity permutation integer. Empty if entity
200201
/// permutations are not required.
202+
/// @param custom_data Custom user data pointer passed to the kernel.
201203
template <dolfinx::scalar T>
202204
void assemble_entities(
203205
la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
@@ -221,7 +223,8 @@ void assemble_entities(
221223
md::mdspan<const T, md::dextents<std::size_t, 2>> coeffs,
222224
std::span<const T> constants, std::span<const std::uint32_t> cell_info0,
223225
std::span<const std::uint32_t> cell_info1,
224-
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms)
226+
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms,
227+
void* custom_data = nullptr)
225228
{
226229
if (entities.empty())
227230
return;
@@ -259,7 +262,7 @@ void assemble_entities(
259262
// Tabulate tensor
260263
std::ranges::fill(Ae, 0);
261264
kernel(Ae.data(), &coeffs(f, 0), constants.data(), cdofs.data(),
262-
&local_entity, &perm, nullptr);
265+
&local_entity, &perm, custom_data);
263266
P0(Ae, cell_info0, cell0, ndim1);
264267
P1T(Ae, cell_info1, cell1, ndim0);
265268

@@ -363,7 +366,8 @@ void assemble_interior_facets(
363366
coeffs,
364367
std::span<const T> constants, std::span<const std::uint32_t> cell_info0,
365368
std::span<const std::uint32_t> cell_info1,
366-
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms)
369+
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms,
370+
void* custom_data = nullptr)
367371
{
368372
if (facets.empty())
369373
return;
@@ -440,7 +444,7 @@ void assemble_interior_facets(
440444
: std::array{perms(cells[0], local_facet[0]),
441445
perms(cells[1], local_facet[1])};
442446
kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(),
443-
local_facet.data(), perm.data(), nullptr);
447+
local_facet.data(), perm.data(), custom_data);
444448

445449
// Local element layout is a 2x2 block matrix with structure
446450
//
@@ -605,12 +609,13 @@ void assemble_matrix(
605609
std::span cells0 = a.domain_arg(IntegralType::cell, 0, i, cell_type_idx);
606610
std::span cells1 = a.domain_arg(IntegralType::cell, 1, i, cell_type_idx);
607611
auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i});
612+
void* custom_data = a.custom_data(IntegralType::cell, i, cell_type_idx);
608613
assert(cells.size() * cstride == coeffs.size());
609614
impl::assemble_cells_matrix(
610615
mat_set, x_dofmap, x, cells, {dofs0, bs0, cells0}, P0,
611616
{dofs1, bs1, cells1}, P1T, bc0, bc1, fn,
612617
md::mdspan(coeffs.data(), cells.size(), cstride), constants,
613-
cell_info0, cell_info1);
618+
cell_info0, cell_info1, custom_data);
614619
}
615620

616621
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> facet_perms;
@@ -646,6 +651,7 @@ void assemble_matrix(
646651
assert(fn);
647652
auto& [coeffs, cstride]
648653
= coefficients.at({IntegralType::interior_facet, i});
654+
void* custom_data = a.custom_data(IntegralType::interior_facet, i, 0);
649655

650656
std::span facets = a.domain(IntegralType::interior_facet, i, 0);
651657
std::span facets0 = a.domain_arg(IntegralType::interior_facet, 0, i, 0);
@@ -661,7 +667,7 @@ void assemble_matrix(
661667
mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)},
662668
P1T, bc0, bc1, fn,
663669
mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), constants,
664-
cell_info0, cell_info1, facet_perms);
670+
cell_info0, cell_info1, facet_perms, custom_data);
665671
}
666672

667673
for (auto itg_type : {fem::IntegralType::exterior_facet,
@@ -688,6 +694,7 @@ void assemble_matrix(
688694
auto fn = a.kernel(itg_type, i, 0);
689695
assert(fn);
690696
auto& [coeffs, cstride] = coefficients.at({itg_type, i});
697+
void* custom_data = a.custom_data(itg_type, i, 0);
691698

692699
std::span e = a.domain(itg_type, i, 0);
693700
mdspanx2_t entities(e.data(), e.size() / 2, 2);
@@ -700,7 +707,7 @@ void assemble_matrix(
700707
mat_set, x_dofmap, x, entities, {dofs0, bs0, entities0}, P0,
701708
{dofs1, bs1, entities1}, P1T, bc0, bc1, fn,
702709
md::mdspan(coeffs.data(), entities.extent(0), cstride), constants,
703-
cell_info0, cell_info1, perms);
710+
cell_info0, cell_info1, perms, custom_data);
704711
}
705712
}
706713
}

cpp/dolfinx/fem/assemble_scalar_impl.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ T assemble_cells(mdspan2_t x_dofmap,
3030
std::span<const std::int32_t> cells, FEkernel<T> auto fn,
3131
std::span<const T> constants,
3232
md::mdspan<const T, md::dextents<std::size_t, 2>> coeffs,
33-
std::span<scalar_value_t<T>> cdofs_b)
33+
std::span<scalar_value_t<T>> cdofs_b,
34+
void* custom_data = nullptr)
3435
{
3536
T value(0);
3637
if (cells.empty())
@@ -49,7 +50,7 @@ T assemble_cells(mdspan2_t x_dofmap,
4950
std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs_b.begin(), 3 * i));
5051

5152
fn(&value, &coeffs(index, 0), constants.data(), cdofs_b.data(), nullptr,
52-
nullptr, nullptr);
53+
nullptr, custom_data);
5354
}
5455

5556
return value;
@@ -77,7 +78,7 @@ T assemble_entities(
7778
FEkernel<T> auto fn, std::span<const T> constants,
7879
md::mdspan<const T, md::dextents<std::size_t, 2>> coeffs,
7980
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms,
80-
std::span<scalar_value_t<T>> cdofs_b)
81+
std::span<scalar_value_t<T>> cdofs_b, void* custom_data = nullptr)
8182
{
8283
T value(0);
8384
if (entities.empty())
@@ -99,7 +100,7 @@ T assemble_entities(
99100
// Permutations
100101
std::uint8_t perm = perms.empty() ? 0 : perms(cell, local_entity);
101102
fn(&value, &coeffs(f, 0), constants.data(), cdofs_b.data(), &local_entity,
102-
&perm, nullptr);
103+
&perm, custom_data);
103104
}
104105

105106
return value;
@@ -120,7 +121,7 @@ T assemble_interior_facets(
120121
md::dynamic_extent>>
121122
coeffs,
122123
md::mdspan<const std::uint8_t, md::dextents<std::size_t, 2>> perms,
123-
std::span<scalar_value_t<T>> cdofs_b)
124+
std::span<scalar_value_t<T>> cdofs_b, void* custom_data = nullptr)
124125
{
125126
T value(0);
126127
if (facets.empty())
@@ -150,7 +151,7 @@ T assemble_interior_facets(
150151
: std::array{perms(cells[0], local_facet[0]),
151152
perms(cells[1], local_facet[1])};
152153
fn(&value, &coeffs(f, 0, 0), constants.data(), cdofs_b.data(),
153-
local_facet.data(), perm.data(), nullptr);
154+
local_facet.data(), perm.data(), custom_data);
154155
}
155156

156157
return value;
@@ -178,11 +179,12 @@ T assemble_scalar(
178179
auto fn = M.kernel(IntegralType::cell, i, 0);
179180
assert(fn);
180181
auto& [coeffs, cstride] = coefficients.at({IntegralType::cell, i});
182+
void* custom_data = M.custom_data(IntegralType::cell, i, 0);
181183
std::span<const std::int32_t> cells = M.domain(IntegralType::cell, i, 0);
182184
assert(cells.size() * cstride == coeffs.size());
183185
value += impl::assemble_cells(
184186
x_dofmap, x, cells, fn, constants,
185-
md::mdspan(coeffs.data(), cells.size(), cstride), cdofs_b);
187+
md::mdspan(coeffs.data(), cells.size(), cstride), cdofs_b, custom_data);
186188
}
187189

188190
mesh::CellType cell_type = mesh->topology()->cell_type();
@@ -204,6 +206,7 @@ T assemble_scalar(
204206
assert(fn);
205207
auto& [coeffs, cstride]
206208
= coefficients.at({IntegralType::interior_facet, i});
209+
void* custom_data = M.custom_data(IntegralType::interior_facet, i, 0);
207210
std::span facets = M.domain(IntegralType::interior_facet, i, 0);
208211

209212
constexpr std::size_t num_adjacent_cells = 2;
@@ -220,7 +223,7 @@ T assemble_scalar(
220223
md::mdspan<const T, md::extents<std::size_t, md::dynamic_extent, 2,
221224
md::dynamic_extent>>(
222225
coeffs.data(), facets.size() / shape1, 2, cstride),
223-
facet_perms, cdofs_b);
226+
facet_perms, cdofs_b, custom_data);
224227
}
225228

226229
for (auto itg_type : {fem::IntegralType::exterior_facet,
@@ -236,6 +239,7 @@ T assemble_scalar(
236239
auto fn = M.kernel(itg_type, i, 0);
237240
assert(fn);
238241
auto& [coeffs, cstride] = coefficients.at({itg_type, i});
242+
void* custom_data = M.custom_data(itg_type, i, 0);
239243

240244
std::span entities = M.domain(itg_type, i, 0);
241245

@@ -248,7 +252,7 @@ T assemble_scalar(
248252
entities.data(), entities.size() / 2, 2),
249253
fn, constants,
250254
md::mdspan(coeffs.data(), entities.size() / 2, cstride), perms,
251-
cdofs_b);
255+
cdofs_b, custom_data);
252256
}
253257
}
254258

0 commit comments

Comments
 (0)