Skip to content

Commit bc527ab

Browse files
committed
Add functions to assemble multiple moment fitting systems
1 parent a350b79 commit bc527ab

File tree

2 files changed

+179
-1
lines changed

2 files changed

+179
-1
lines changed

source/pbat/math/MomentFitting.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,19 @@ TEST_CASE("[math] MomentFitting")
8181
// Assert
8282
CHECK((w1.array() >= Scalar(0)).all());
8383
CHECK_LT(error.maxCoeff(), 1e-10);
84+
85+
SUBCASE("Can also solve global sparse linear system for quadrature weights")
86+
{
87+
auto [M, B, P] = math::ReferenceMomentFittingSystems<kOrder>(
88+
S1,
89+
X1.bottomRows(kDims),
90+
S2,
91+
X2.bottomRows(kDims),
92+
w2);
93+
CSRMatrix GM = math::BlockDiagonalReferenceMomentFittingSystem(M, B, P);
94+
CHECK_EQ(GM.rows(), 2*math::OrthonormalPolynomialBasis<kDims, kOrder>::kSize);
95+
CHECK_EQ(GM.cols(), 8);
96+
CHECK_EQ(GM.nonZeros(), 2 * math::OrthonormalPolynomialBasis<kDims, kOrder>::kSize * 4);
97+
}
8498
}
8599
}

source/pbat/math/MomentFitting.h

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <exception>
1313
#include <limits>
1414
#include <tbb/parallel_for.h>
15+
#include <tuple>
1516
#include <unsupported/Eigen/NNLS>
1617
#include <utility>
1718

@@ -188,7 +189,7 @@ std::pair<VectorX, VectorX> TransferQuadrature(
188189
Eigen::MatrixBase<TDerivedXi1> const& Xi1,
189190
Eigen::DenseBase<TDerivedS2> const& S2,
190191
Eigen::MatrixBase<TDerivedXi2> const& Xi2,
191-
Eigen::MatrixBase<TDerivedWg2> const& wi2,
192+
Eigen::DenseBase<TDerivedWg2> const& wi2,
192193
bool bEvaluateError = false,
193194
Index maxIterations = 10,
194195
Scalar precision = std::numeric_limits<Scalar>::epsilon())
@@ -278,6 +279,169 @@ std::pair<VectorX, VectorX> TransferQuadrature(
278279
return {wi1, error};
279280
}
280281

282+
/**
283+
* @brief
284+
*
285+
* @tparam Order
286+
* @tparam TDerivedS1
287+
* @tparam TDerivedX1
288+
* @tparam TDerivedS2
289+
* @tparam TDerivedX2
290+
* @tparam TDerivedW2
291+
* @param S1
292+
* @param X1
293+
* @param S2
294+
* @param X2
295+
* @param w2
296+
* @return std::tuple<MatrixX, MatrixX, IndexVectorX>
297+
*/
298+
template <
299+
int Order,
300+
class TDerivedS1,
301+
class TDerivedX1,
302+
class TDerivedS2,
303+
class TDerivedX2,
304+
class TDerivedW2>
305+
std::tuple<MatrixX /*P*/, MatrixX /*B*/, IndexVectorX /*prefix into columns of P*/>
306+
ReferenceMomentFittingSystems(
307+
Eigen::DenseBase<TDerivedS1> const& S1,
308+
Eigen::MatrixBase<TDerivedX1> const& X1,
309+
Eigen::DenseBase<TDerivedS2> const& S2,
310+
Eigen::MatrixBase<TDerivedX2> const& X2,
311+
Eigen::DenseBase<TDerivedW2> const& w2)
312+
{
313+
// Compute adjacency graph from simplices s to their quadrature points Xi
314+
using common::ArgSort;
315+
using common::Counts;
316+
using common::CumSum;
317+
using common::ToEigen;
318+
auto nsimplices =
319+
std::max(*std::max_element(S1.begin(), S1.end()), *std::max_element(S2.begin(), S2.end())) +
320+
1;
321+
IndexVectorX S1P = ToEigen(CumSum(Counts<Index>(S1.begin(), S1.end(), nsimplices)));
322+
IndexVectorX S2P = ToEigen(CumSum(Counts<Index>(S2.begin(), S2.end(), nsimplices)));
323+
IndexVectorX S1N =
324+
ToEigen(ArgSort<Index>(S1.size(), [&](auto si, auto sj) { return S1(si) < S1(sj); }));
325+
IndexVectorX S2N =
326+
ToEigen(ArgSort<Index>(S2.size(), [&](auto si, auto sj) { return S2(si) < S2(sj); }));
327+
// Assemble moment fitting matrices and their rhs
328+
auto fPolyRows = [](MatrixX const& Xg) {
329+
if (Xg.rows() == 1)
330+
return OrthonormalPolynomialBasis<1, Order>::kSize;
331+
if (Xg.rows() == 2)
332+
return OrthonormalPolynomialBasis<2, Order>::kSize;
333+
if (Xg.rows() == 3)
334+
return OrthonormalPolynomialBasis<3, Order>::kSize;
335+
throw std::invalid_argument(
336+
"Expected quadrature points in reference simplex space of dimensions (i.e. rows) 1,2 "
337+
"or 3.");
338+
};
339+
auto fAssembleSystem = [](MatrixX const& Xg1,
340+
MatrixX const& Xg2,
341+
VectorX const& wg2) -> std::pair<MatrixX, VectorX> {
342+
if (Xg1.rows() == 1)
343+
{
344+
OrthonormalPolynomialBasis<1, Order> P{};
345+
auto M = ReferenceMomentFittingMatrix(P, Xg1);
346+
auto b = Integrate(P, Xg2, wg2);
347+
return {M, b};
348+
}
349+
if (Xg1.rows() == 2)
350+
{
351+
OrthonormalPolynomialBasis<2, Order> P{};
352+
auto M = ReferenceMomentFittingMatrix(P, Xg1);
353+
auto b = Integrate(P, Xg2, wg2);
354+
return {M, b};
355+
}
356+
if (Xg1.rows() == 3)
357+
{
358+
OrthonormalPolynomialBasis<3, Order> P{};
359+
auto M = ReferenceMomentFittingMatrix(P, Xg1);
360+
auto b = Integrate(P, Xg2, wg2);
361+
return {M, b};
362+
}
363+
throw std::invalid_argument(
364+
"Expected quadrature points in reference simplex space of dimensions (i.e. rows) 1,2 "
365+
"or 3.");
366+
};
367+
auto nrows = fPolyRows(X1);
368+
MatrixX P(nrows, S1N.size());
369+
MatrixX B(nrows, nsimplices);
370+
B.setZero();
371+
tbb::parallel_for(Index(0), nsimplices, [&](Index s) {
372+
auto S1begin = S1P(s);
373+
auto S1end = S1P(s + 1);
374+
if (S1end > S1begin)
375+
{
376+
auto s1inds = S1N(Eigen::seq(S1begin, S1end - 1));
377+
MatrixX Xg1 = X1(Eigen::placeholders::all, s1inds);
378+
auto S2begin = S2P(s);
379+
auto S2end = S2P(s + 1);
380+
auto s2inds = S2N(Eigen::seq(S2begin, S2end - 1));
381+
MatrixX Xg2 = X2(Eigen::placeholders::all, s2inds);
382+
VectorX wg2 = w2(s2inds);
383+
auto [Ps, bs] = fAssembleSystem(Xg1, Xg2, wg2);
384+
P.block(0, S1begin, nrows, S1end - S1begin) = Ps;
385+
B.col(s) = bs;
386+
}
387+
});
388+
return std::make_tuple(P, B, S1P);
389+
}
390+
391+
/**
392+
* @brief
393+
*
394+
* @tparam TDerivedM
395+
* @tparam TDerivedB
396+
* @tparam TDerivedP
397+
* @param M
398+
* @param B
399+
* @param P
400+
* @return CSRMatrix The block diagonal row sparse matrix GM, whose diagonal blocks are the
401+
* individual reference moment fitting matrices in M, such that GM @ w = B.reshaped() is the global
402+
* sparse linear system to solve for quadrature weights w.
403+
*/
404+
template <class TDerivedM, class TDerivedB, class TDerivedP>
405+
CSRMatrix BlockDiagonalReferenceMomentFittingSystem(
406+
Eigen::MatrixBase<TDerivedM> const& M,
407+
Eigen::MatrixBase<TDerivedB> const& B,
408+
Eigen::DenseBase<TDerivedP> const& P)
409+
{
410+
auto const nblocks = P.size() - 1;
411+
auto const nblockrows = M.rows();
412+
auto const nrows = nblockrows * nblocks;
413+
auto const ncols = P(Eigen::placeholders::last);
414+
CSRMatrix GM(nrows, ncols);
415+
std::vector<Index> reserves(nrows);
416+
for (auto b = 0; b < nblocks; ++b)
417+
{
418+
auto begin = P(b);
419+
auto end = P(b + 1);
420+
auto const nblockcols = end - begin;
421+
auto const offset = b * nblockrows;
422+
for (auto i = 0; i < nblockrows; ++i)
423+
reserves[offset + i] = nblockcols;
424+
}
425+
GM.reserve(reserves);
426+
for (auto b = 0; b < nblocks; ++b)
427+
{
428+
auto begin = P(b);
429+
auto end = P(b + 1);
430+
auto const nblockcols = end - begin;
431+
auto Mb = M.block(0, begin, nblockrows, nblockcols);
432+
auto const roffset = b * nblockrows;
433+
auto const coffset = begin;
434+
for (auto i = 0; i < nblockrows; ++i)
435+
{
436+
for (auto j = 0; j < nblockcols; ++j)
437+
{
438+
GM.insert(roffset + i, coffset + j) = Mb(i, j);
439+
}
440+
}
441+
}
442+
return GM;
443+
}
444+
281445
} // namespace math
282446
} // namespace pbat
283447

0 commit comments

Comments
 (0)