Skip to content

Commit eb7833f

Browse files
committed
Implement quadrature transfer for fixed points
1 parent d29fb31 commit eb7833f

File tree

4 files changed

+194
-14
lines changed

4 files changed

+194
-14
lines changed

source/pbat/common/Indexing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace common {
1111
TEST_CASE("[common] Cumulative sums are computable from any integral range type")
1212
{
1313
std::array<Index, 3> v{5, 10, 15};
14-
auto const cs = cumsum(v);
14+
auto const cs = CumSum(v);
1515
CHECK_EQ(cs, std::vector<Index>{0, 5, 15, 30});
1616
}
1717

source/pbat/common/Indexing.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "Concepts.h"
55

6+
#include <concepts>
67
#include <numeric>
78
#include <pbat/Aliases.h>
89
#include <ranges>
@@ -12,7 +13,7 @@ namespace pbat {
1213
namespace common {
1314

1415
template <CContiguousIndexRange R>
15-
std::vector<Index> cumsum(R&& sizes)
16+
std::vector<Index> CumSum(R&& sizes)
1617
{
1718
namespace rng = std::ranges;
1819
std::vector<Index> cs{};
@@ -25,6 +26,15 @@ std::vector<Index> cumsum(R&& sizes)
2526
return cs;
2627
}
2728

29+
template <std::integral TIndex>
30+
std::vector<TIndex> Counts(auto begin, auto end, auto ncounts)
31+
{
32+
std::vector<TIndex> counts(ncounts, TIndex(0));
33+
for (auto it = begin; it != end; ++it)
34+
++counts[*it];
35+
return counts;
36+
}
37+
2838
} // namespace common
2939
} // namespace pbat
3040

source/pbat/math/MomentFitting.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,8 @@ void TestFixedQuadrature(Scalar precision)
1818
pbat::math::SymmetricSimplexPolynomialQuadratureRule<Dims, Order> Q{};
1919
auto Xg = pbat::common::ToEigen(Q.points).reshaped(Q.kDims + 1, Q.kPoints);
2020
auto wg = pbat::common::ToEigen(Q.weights);
21-
auto M = pbat::math::ReferenceMomentFittingMatrix(P, Q);
22-
auto b = pbat::Vector<decltype(P)::kSize>::Zero().eval();
23-
for (auto g = 0; g < Q.kPoints; ++g)
24-
{
25-
b += wg(g) * P.eval(Xg.col(g).segment<Dims>(1));
26-
}
27-
auto w = pbat::math::MomentFittedWeights(M, b, 20);
21+
auto w =
22+
pbat::math::TransferQuadrature(P, Xg.bottomRows(Q.kDims), Xg.bottomRows(Q.kDims), wg, 20);
2823
CHECK_LT((w - wg).squaredNorm(), precision);
2924
CHECK((w.array() >= Scalar(0)).all());
3025
};

source/pbat/math/MomentFitting.h

Lines changed: 180 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
#define PBAT_MATH_MOMENT_FITTING_H
33

44
#include "Concepts.h"
5+
#include "PolynomialBasis.h"
56
#include "pbat/Aliases.h"
7+
#include "pbat/common/ArgSort.h"
68
#include "pbat/common/Eigen.h"
9+
#include "pbat/common/Indexing.h"
710

11+
#include <algorithm>
12+
#include <exception>
813
#include <limits>
914
#include <tbb/parallel_for.h>
1015
#include <unsupported/Eigen/NNLS>
@@ -51,22 +56,31 @@ Matrix<TBasis::kSize, Eigen::Dynamic> ReferenceMomentFittingMatrix(TBasis const&
5156
return P;
5257
}
5358

54-
template <CPolynomialBasis TBasis, int Dims, int Order>
59+
template <CPolynomialBasis TBasis, class TDerivedXg>
5560
Matrix<TBasis::kSize, Eigen::Dynamic>
56-
ReferenceMomentFittingMatrix(TBasis const& Pb, DynamicQuadrature<Dims, Order> const& Q)
61+
ReferenceMomentFittingMatrix(TBasis const& Pb, Eigen::MatrixBase<TDerivedXg> const& Xg)
5762
{
58-
using QuadratureType = DynamicQuadrature<Dims, Order>;
63+
using QuadratureType = DynamicQuadrature<TBasis::kDims, TBasis::kOrder>;
5964
static_assert(
6065
TBasis::kDims == QuadratureType::kDims,
6166
"Dimensions of the quadrature rule and the polynomial basis must match, i.e. a k-D "
6267
"polynomial must be fit in a k-D integration domain.");
63-
Matrix<TBasis::kSize, Eigen::Dynamic> P(TBasis::kSize, Q.weights.size());
64-
auto Xg = common::ToEigen(Q.points).reshaped(QuadratureType::kDims, Q.weights.size());
68+
Matrix<TBasis::kSize, Eigen::Dynamic> P(TBasis::kSize, Xg.cols());
6569
for (auto g = 0u; g < Xg.cols(); ++g)
6670
P.col(g) = Pb.eval(Xg.col(g));
6771
return P;
6872
}
6973

74+
/**
75+
* @brief Computes non-negative quadrature weights wg by moment fitting.
76+
* @tparam TDerivedP
77+
* @tparam TDerivedB
78+
* @param P Moment fitting matrix
79+
* @param b Target integrated polynomials
80+
* @param maxIterations Maximum number of non-negative least-squares active set solver
81+
* @param precision Convergence threshold
82+
* @return
83+
*/
7084
template <class TDerivedP, class TDerivedB>
7185
VectorX MomentFittedWeights(
7286
Eigen::MatrixBase<TDerivedP> const& P,
@@ -83,6 +97,167 @@ VectorX MomentFittedWeights(
8397
return w;
8498
}
8599

100+
/**
101+
* @brief
102+
* @tparam Polynomial
103+
* @tparam TDerivedXg
104+
* @tparam TDerivedWg
105+
* @param P
106+
* @param Xg |#dims|x|#quad.pts.|
107+
* @param wg |#quad.pts.|x1
108+
* @return
109+
*/
110+
template <CPolynomialBasis Polynomial, class TDerivedXg, class TDerivedWg>
111+
Vector<Polynomial::kSize> Integrate(
112+
Polynomial const& P,
113+
Eigen::MatrixBase<TDerivedXg> const& Xg,
114+
Eigen::DenseBase<TDerivedWg> const& wg)
115+
{
116+
Vector<Polynomial::kSize> b{};
117+
b.setZero();
118+
for (auto g = 0; g < wg.size(); ++g)
119+
b += wg(g) * P.eval(Xg.col(g));
120+
return b;
121+
}
122+
123+
/**
124+
* @brief
125+
* @tparam TDerivedXg1
126+
* @tparam TDerivedXg2
127+
* @tparam TDerivedWg2
128+
* @tparam Polynomial
129+
* @param P
130+
* @param Xg1 |#dims|x|#new quad.pts.|
131+
* @param Xg2 |#dims|x|#old quad.pts.|
132+
* @param wg2 |#old quad.pts.|x1
133+
* @param maxIterations Maximum number of non-negative least-squares active set solver
134+
* @param precision Convergence threshold
135+
* @return
136+
*/
137+
template <CPolynomialBasis Polynomial, class TDerivedXg1, class TDerivedXg2, class TDerivedWg2>
138+
Vector<TDerivedXg1::ColsAtCompileTime> TransferQuadrature(
139+
Polynomial const& P,
140+
Eigen::MatrixBase<TDerivedXg1> const& Xg1,
141+
Eigen::MatrixBase<TDerivedXg2> const& Xg2,
142+
Eigen::DenseBase<TDerivedWg2> const& wg2,
143+
Index maxIterations = 10,
144+
Scalar precision = std::numeric_limits<Scalar>::epsilon())
145+
{
146+
auto b = Integrate(P, Xg2, wg2);
147+
auto M = ReferenceMomentFittingMatrix(P, Xg1);
148+
auto w = MomentFittedWeights(M, b, maxIterations, precision);
149+
return w;
150+
}
151+
152+
/**
153+
* @brief Obtain weights wi1 by transferring an existing quadrature rule (Xi2,wi2) defined on a
154+
* domain composed of simplices onto a new quadrature rule (Xi1,wi1) defined on the same domain,
155+
* given fixed quadrature points Xi1.
156+
*
157+
* @tparam Order Order of the quadrature rules
158+
* @tparam TDerivedS1
159+
* @tparam TDerivedXi1
160+
* @tparam TDerivedS2
161+
* @tparam TDerivedXi2
162+
* @tparam TDerivedWg2
163+
* @param S1 1x|Xi1.cols()| Index array giving the simplex containing the corresponding quadrature
164+
* point in columns of Xi1.
165+
* @param Xi1 |#dims|x|#new quad.pts.| array of quadrature point positions defined in reference
166+
* simplex space
167+
* @param S2 |Xi2.cols()|x1 index array giving the simplex containing the corresponding quadrature
168+
* point in columns of Xi2.
169+
* @param Xi2 |#dims|x|#old quad.pts.| array of quadrature point positions defined in reference
170+
* simplex space
171+
* @param wi2 |#old quad.pts.|x1 array of quadrature weights
172+
* @param maxIterations Maximum number of non-negative least-squares active set solver
173+
* @param precision Convergence threshold
174+
* @return
175+
*/
176+
template <
177+
auto Order,
178+
class TDerivedS1,
179+
class TDerivedXi1,
180+
class TDerivedS2,
181+
class TDerivedXi2,
182+
class TDerivedWg2>
183+
VectorX TransferQuadrature(
184+
Eigen::DenseBase<TDerivedS1> const& S1,
185+
Eigen::MatrixBase<TDerivedXi1> const& Xi1,
186+
Eigen::DenseBase<TDerivedS2> const& S2,
187+
Eigen::MatrixBase<TDerivedXi2> const& Xi2,
188+
Eigen::MatrixBase<TDerivedWg2> const& wi2,
189+
Index maxIterations = 10,
190+
Scalar precision = std::numeric_limits<Scalar>::epsilon())
191+
{
192+
// Compute adjacency graph from simplices s to their quadrature points Xi
193+
auto nsimplices =
194+
std::max(*std::max_element(S1.begin(), S1.end()), *std::max_element(S2.begin(), S2.end())) +
195+
1;
196+
std::vector<Index> S1P =
197+
common::CumSum(common::Counts<Index>(S1.begin(), S1.end(), nsimplices));
198+
std::vector<Index> S2P =
199+
common::CumSum(common::Counts<Index>(S2.begin(), S2.end(), nsimplices));
200+
std::vector<Index> S1N =
201+
common::ArgSort<Index>(S1.size(), [](auto si, auto sj) { return S1(si) < S1(sj); });
202+
std::vector<Index> S2N =
203+
common::ArgSort<Index>(S2.size(), [](auto si, auto sj) { return S2(si) < S2(sj); });
204+
// Find weights wg1 that fit the given quadrature rule Xi2, wi2 on simplices S2
205+
auto fSolveWeights = [maxIterations,
206+
precision](MatrixX const& Xg1, MatrixX const& Xg2, VectorX const& wg2) {
207+
if (Xg1.rows() == 1)
208+
{
209+
return TransferQuadrature(
210+
OrthonormalPolynomialBasis<1, Order>{},
211+
Xg1,
212+
Xg2,
213+
wg2,
214+
maxIterations,
215+
precision);
216+
}
217+
if (Xg1.rows() == 2)
218+
{
219+
return TransferQuadrature(
220+
OrthonormalPolynomialBasis<2, Order>{},
221+
Xg1,
222+
Xg2,
223+
wg2,
224+
maxIterations,
225+
precision);
226+
}
227+
if (Xg1.rows() == 3)
228+
{
229+
return TransferQuadrature(
230+
OrthonormalPolynomialBasis<3, Order>{},
231+
Xg1,
232+
Xg2,
233+
wg2,
234+
maxIterations,
235+
precision);
236+
}
237+
throw std::invalid_argument(
238+
"Expected quadrature points in reference simplex space of dimensions (i.e. rows) 1,2 "
239+
"or 3.");
240+
};
241+
242+
VectorX wg1 = VectorX::Zero(Xi1.cols());
243+
tbb::parallel_for(Index(0), nsimplices, [&](Index s) {
244+
auto S1begin = S1P[s];
245+
auto S1end = S1P[s + 1];
246+
if (S1end > S1begin)
247+
{
248+
auto s1inds = S1N(Eigen::seq(S1begin, S1end - 1));
249+
MatrixX Xg1 = Xi1(Eigen::placeholders::all, s1inds);
250+
auto S2begin = S2P[s];
251+
auto S2end = S2P[s + 1];
252+
auto s2inds = S2N(Eigen::seq(S2begin, S2end - 1));
253+
MatrixX Xg2 = Xi2(Eigen::placeholders::all, s2inds);
254+
VectorX wg2 = wi2(s2inds);
255+
wg1(s1inds) = fSolveWeights(Xg1, Xg2, wg2);
256+
}
257+
});
258+
return wg1;
259+
}
260+
86261
} // namespace math
87262
} // namespace pbat
88263

0 commit comments

Comments
 (0)