@@ -60,6 +60,7 @@ using mdspan2_t = md::mdspan<const std::int32_t, md::dextents<std::size_t, 2>>;
6060// / function mesh.
6161// / @param cell_info1 Cell permutation information for the trial
6262// / function mesh.
63+ // / @param custom_data Custom user data pointer passed to the kernel.
6364template <dolfinx::scalar T>
6465void assemble_cells_matrix (
6566 la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
@@ -74,7 +75,7 @@ void assemble_cells_matrix(
7475 std::span<const std::int8_t > bc1, FEkernel<T> auto kernel,
7576 md::mdspan<const T, md::dextents<std::size_t , 2 >> coeffs,
7677 std::span<const T> constants, std::span<const std::uint32_t > cell_info0,
77- std::span<const std::uint32_t > cell_info1)
78+ std::span<const std::uint32_t > cell_info1, void * custom_data = nullptr )
7879{
7980 if (cells.empty ())
8081 return ;
@@ -109,7 +110,7 @@ void assemble_cells_matrix(
109110 // Tabulate tensor
110111 std::ranges::fill (Ae, 0 );
111112 kernel (Ae.data (), &coeffs (c, 0 ), constants.data (), cdofs.data (), nullptr ,
112- nullptr , nullptr );
113+ nullptr , custom_data );
113114
114115 // Compute A = P_0 \tilde{A} P_1^T (dof transformation)
115116 P0 (Ae, cell_info0, cell0, ndim1); // B = P0 \tilde{A}
@@ -198,6 +199,7 @@ void assemble_cells_matrix(
198199// / function mesh.
199200// / @param[in] perms Entity permutation integer. Empty if entity
200201// / permutations are not required.
202+ // / @param custom_data Custom user data pointer passed to the kernel.
201203template <dolfinx::scalar T>
202204void assemble_entities (
203205 la::MatSet<T> auto mat_set, mdspan2_t x_dofmap,
@@ -221,7 +223,8 @@ void assemble_entities(
221223 md::mdspan<const T, md::dextents<std::size_t , 2 >> coeffs,
222224 std::span<const T> constants, std::span<const std::uint32_t > cell_info0,
223225 std::span<const std::uint32_t > cell_info1,
224- md::mdspan<const std::uint8_t , md::dextents<std::size_t , 2 >> perms)
226+ md::mdspan<const std::uint8_t , md::dextents<std::size_t , 2 >> perms,
227+ void * custom_data = nullptr )
225228{
226229 if (entities.empty ())
227230 return ;
@@ -259,7 +262,7 @@ void assemble_entities(
259262 // Tabulate tensor
260263 std::ranges::fill (Ae, 0 );
261264 kernel (Ae.data (), &coeffs (f, 0 ), constants.data (), cdofs.data (),
262- &local_entity, &perm, nullptr );
265+ &local_entity, &perm, custom_data );
263266 P0 (Ae, cell_info0, cell0, ndim1);
264267 P1T (Ae, cell_info1, cell1, ndim0);
265268
@@ -363,7 +366,8 @@ void assemble_interior_facets(
363366 coeffs,
364367 std::span<const T> constants, std::span<const std::uint32_t > cell_info0,
365368 std::span<const std::uint32_t > cell_info1,
366- md::mdspan<const std::uint8_t , md::dextents<std::size_t , 2 >> perms)
369+ md::mdspan<const std::uint8_t , md::dextents<std::size_t , 2 >> perms,
370+ void * custom_data = nullptr )
367371{
368372 if (facets.empty ())
369373 return ;
@@ -440,7 +444,7 @@ void assemble_interior_facets(
440444 : std::array{perms (cells[0 ], local_facet[0 ]),
441445 perms (cells[1 ], local_facet[1 ])};
442446 kernel (Ae.data (), &coeffs (f, 0 , 0 ), constants.data (), cdofs.data (),
443- local_facet.data (), perm.data (), nullptr );
447+ local_facet.data (), perm.data (), custom_data );
444448
445449 // Local element layout is a 2x2 block matrix with structure
446450 //
@@ -605,12 +609,13 @@ void assemble_matrix(
605609 std::span cells0 = a.domain_arg (IntegralType::cell, 0 , i, cell_type_idx);
606610 std::span cells1 = a.domain_arg (IntegralType::cell, 1 , i, cell_type_idx);
607611 auto & [coeffs, cstride] = coefficients.at ({IntegralType::cell, i});
612+ void * custom_data = a.custom_data (IntegralType::cell, i, cell_type_idx);
608613 assert (cells.size () * cstride == coeffs.size ());
609614 impl::assemble_cells_matrix (
610615 mat_set, x_dofmap, x, cells, {dofs0, bs0, cells0}, P0,
611616 {dofs1, bs1, cells1}, P1T, bc0, bc1, fn,
612617 md::mdspan (coeffs.data (), cells.size (), cstride), constants,
613- cell_info0, cell_info1);
618+ cell_info0, cell_info1, custom_data );
614619 }
615620
616621 md::mdspan<const std::uint8_t , md::dextents<std::size_t , 2 >> facet_perms;
@@ -646,6 +651,7 @@ void assemble_matrix(
646651 assert (fn);
647652 auto & [coeffs, cstride]
648653 = coefficients.at ({IntegralType::interior_facet, i});
654+ void * custom_data = a.custom_data (IntegralType::interior_facet, i, 0 );
649655
650656 std::span facets = a.domain (IntegralType::interior_facet, i, 0 );
651657 std::span facets0 = a.domain_arg (IntegralType::interior_facet, 0 , i, 0 );
@@ -661,7 +667,7 @@ void assemble_matrix(
661667 mdspanx22_t (facets1.data (), facets1.size () / 4 , 2 , 2 )},
662668 P1T, bc0, bc1, fn,
663669 mdspanx2x_t (coeffs.data (), facets.size () / 4 , 2 , cstride), constants,
664- cell_info0, cell_info1, facet_perms);
670+ cell_info0, cell_info1, facet_perms, custom_data );
665671 }
666672
667673 for (auto itg_type : {fem::IntegralType::exterior_facet,
@@ -688,6 +694,7 @@ void assemble_matrix(
688694 auto fn = a.kernel (itg_type, i, 0 );
689695 assert (fn);
690696 auto & [coeffs, cstride] = coefficients.at ({itg_type, i});
697+ void * custom_data = a.custom_data (itg_type, i, 0 );
691698
692699 std::span e = a.domain (itg_type, i, 0 );
693700 mdspanx2_t entities (e.data (), e.size () / 2 , 2 );
@@ -700,7 +707,7 @@ void assemble_matrix(
700707 mat_set, x_dofmap, x, entities, {dofs0, bs0, entities0}, P0,
701708 {dofs1, bs1, entities1}, P1T, bc0, bc1, fn,
702709 md::mdspan (coeffs.data (), entities.extent (0 ), cstride), constants,
703- cell_info0, cell_info1, perms);
710+ cell_info0, cell_info1, perms, custom_data );
704711 }
705712 }
706713 }
0 commit comments