Skip to content

Commit d29fb31

Browse files
committed
Implement moment fitting via non-negative least-squares
1 parent 1fe982e commit d29fb31

File tree

6 files changed

+153
-16
lines changed

6 files changed

+153
-16
lines changed

source/pbat/math/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ target_sources(PhysicsBasedAnimationToolkit_PhysicsBasedAnimationToolkit
77
"IntegerArithmeticChecks.h"
88
"LinearOperator.h"
99
"Math.h"
10+
"MomentFitting.h"
1011
"PolynomialBasis.h"
1112
"Rational.h"
1213
"SymmetricQuadratureRules.h"
@@ -16,6 +17,7 @@ target_sources(PhysicsBasedAnimationToolkit_PhysicsBasedAnimationToolkit
1617
"GaussQuadrature.cpp"
1718
"IntegerArithmeticChecks.cpp"
1819
"LinearOperator.cpp"
20+
"MomentFitting.cpp"
1921
"PolynomialBasis.cpp"
2022
"Rational.cpp"
2123
"SymmetricQuadratureRules.cpp"

source/pbat/math/Concepts.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ concept CPolynomialQuadratureRule = requires(Q q)
6666
requires std::is_integral_v<decltype(Q::kOrder)>;
6767
};
6868

69+
template <class Q>
70+
concept CFixedPointPolynomialQuadratureRule =
71+
CFixedPointQuadratureRule<Q> and CPolynomialQuadratureRule<Q>;
72+
6973
} // namespace math
7074
} // namespace pbat
7175

source/pbat/math/Math.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "GaussQuadrature.h"
66
#include "IntegerArithmeticChecks.h"
77
#include "LinearOperator.h"
8+
#include "MomentFitting.h"
89
#include "PolynomialBasis.h"
910
#include "Rational.h"
1011
#include "SymmetricQuadratureRules.h"

source/pbat/math/MomentFitting.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "MomentFitting.h"
2+
3+
#include "PolynomialBasis.h"
4+
#include "SymmetricQuadratureRules.h"
5+
#include "pbat/common/ConstexprFor.h"
6+
#include "pbat/common/Eigen.h"
7+
8+
#include <doctest/doctest.h>
9+
10+
namespace pbat {
11+
namespace math {
12+
namespace test {
13+
14+
template <auto Dims, auto Order>
15+
void TestFixedQuadrature(Scalar precision)
16+
{
17+
pbat::math::OrthonormalPolynomialBasis<Dims, Order> P{};
18+
pbat::math::SymmetricSimplexPolynomialQuadratureRule<Dims, Order> Q{};
19+
auto Xg = pbat::common::ToEigen(Q.points).reshaped(Q.kDims + 1, Q.kPoints);
20+
auto wg = pbat::common::ToEigen(Q.weights);
21+
auto M = pbat::math::ReferenceMomentFittingMatrix(P, Q);
22+
auto b = pbat::Vector<decltype(P)::kSize>::Zero().eval();
23+
for (auto g = 0; g < Q.kPoints; ++g)
24+
{
25+
b += wg(g) * P.eval(Xg.col(g).segment<Dims>(1));
26+
}
27+
auto w = pbat::math::MomentFittedWeights(M, b, 20);
28+
CHECK_LT((w - wg).squaredNorm(), precision);
29+
CHECK((w.array() >= Scalar(0)).all());
30+
};
31+
32+
} // namespace test
33+
} // namespace math
34+
} // namespace pbat
35+
36+
TEST_CASE("[math] MomentFitting")
37+
{
38+
using namespace pbat;
39+
SUBCASE("Moment fitting reproduces fixed quadrature rule")
40+
{
41+
Scalar constexpr precision(1e-10);
42+
math::test::TestFixedQuadrature<1, 1>(precision);
43+
math::test::TestFixedQuadrature<1, 3>(precision);
44+
45+
math::test::TestFixedQuadrature<2, 1>(precision);
46+
math::test::TestFixedQuadrature<2, 2>(precision);
47+
math::test::TestFixedQuadrature<2, 3>(precision);
48+
math::test::TestFixedQuadrature<2, 4>(precision);
49+
50+
math::test::TestFixedQuadrature<3, 1>(precision);
51+
math::test::TestFixedQuadrature<3, 2>(precision);
52+
math::test::TestFixedQuadrature<3, 3>(precision);
53+
math::test::TestFixedQuadrature<3, 4>(precision);
54+
}
55+
}

source/pbat/math/MomentFitting.h

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#ifndef PBAT_MATH_MOMENT_FITTING_H
2+
#define PBAT_MATH_MOMENT_FITTING_H
3+
4+
#include "Concepts.h"
5+
#include "pbat/Aliases.h"
6+
#include "pbat/common/Eigen.h"
7+
8+
#include <limits>
9+
#include <tbb/parallel_for.h>
10+
#include <unsupported/Eigen/NNLS>
11+
12+
namespace pbat {
13+
namespace math {
14+
15+
template <int Dims, int Order>
16+
struct DynamicQuadrature
17+
{
18+
static auto constexpr kOrder = Order;
19+
static auto constexpr kDims = Dims;
20+
21+
MatrixX points; ///< |kDims| x |#weights| array of quadrature points Xg
22+
VectorX weights; ///< Array of quadrature weights wg associated with Xg
23+
};
24+
25+
template <CPolynomialBasis TBasis, CFixedPointPolynomialQuadratureRule TQuad>
26+
Matrix<TBasis::kSize, TQuad::kPoints> ReferenceMomentFittingMatrix(TBasis const& Pb, TQuad const& Q)
27+
{
28+
static_assert(
29+
TBasis::kDims == TQuad::kDims,
30+
"Dimensions of the quadrature rule and the polynomial basis must match, i.e. a k-D "
31+
"polynomial must be fit in a k-D integration domain.");
32+
Matrix<TBasis::kSize, TQuad::kPoints> P{};
33+
auto Xg = common::ToEigen(Q.points).reshaped(TQuad::kDims + 1, TQuad::kPoints);
34+
// Eigen::Map<Matrix<TQuad::kDims + 1, TQuad::kPoints> const> Xg(Q.points.data());
35+
for (auto g = 0u; g < TQuad::kPoints; ++g)
36+
P.col(g) = Pb.eval(Xg.col(g).template segment<TQuad::kDims>(1));
37+
return P;
38+
}
39+
40+
template <CPolynomialBasis TBasis, CPolynomialQuadratureRule TQuad>
41+
Matrix<TBasis::kSize, Eigen::Dynamic> ReferenceMomentFittingMatrix(TBasis const& Pb, TQuad const& Q)
42+
{
43+
static_assert(
44+
TBasis::kDims == TQuad::kDims,
45+
"Dimensions of the quadrature rule and the polynomial basis must match, i.e. a k-D "
46+
"polynomial must be fit in a k-D integration domain.");
47+
Matrix<TBasis::kSize, Eigen::Dynamic> P(TBasis::kSize, Q.weights.size());
48+
auto Xg = common::ToEigen(Q.points).reshaped(TQuad::kDims + 1, Q.weights.size());
49+
for (auto g = 0u; g < Xg.cols(); ++g)
50+
P.col(g) = Pb.eval(Xg.col(g).template segment<TQuad::kDims>(1));
51+
return P;
52+
}
53+
54+
template <CPolynomialBasis TBasis, int Dims, int Order>
55+
Matrix<TBasis::kSize, Eigen::Dynamic>
56+
ReferenceMomentFittingMatrix(TBasis const& Pb, DynamicQuadrature<Dims, Order> const& Q)
57+
{
58+
using QuadratureType = DynamicQuadrature<Dims, Order>;
59+
static_assert(
60+
TBasis::kDims == QuadratureType::kDims,
61+
"Dimensions of the quadrature rule and the polynomial basis must match, i.e. a k-D "
62+
"polynomial must be fit in a k-D integration domain.");
63+
Matrix<TBasis::kSize, Eigen::Dynamic> P(TBasis::kSize, Q.weights.size());
64+
auto Xg = common::ToEigen(Q.points).reshaped(QuadratureType::kDims, Q.weights.size());
65+
for (auto g = 0u; g < Xg.cols(); ++g)
66+
P.col(g) = Pb.eval(Xg.col(g));
67+
return P;
68+
}
69+
70+
template <class TDerivedP, class TDerivedB>
71+
VectorX MomentFittedWeights(
72+
Eigen::MatrixBase<TDerivedP> const& P,
73+
Eigen::DenseBase<TDerivedB> const& b,
74+
Index maxIterations = 10,
75+
Scalar precision = std::numeric_limits<Scalar>::epsilon())
76+
{
77+
using MatrixType = TDerivedP;
78+
Eigen::NNLS<MatrixType> nnls{};
79+
nnls.compute(P.derived());
80+
nnls.setMaxIterations(maxIterations);
81+
nnls.setTolerance(precision);
82+
auto w = nnls.solve(b);
83+
return w;
84+
}
85+
86+
} // namespace math
87+
} // namespace pbat
88+
89+
#endif // PBAT_MATH_MOMENT_FITTING_H

source/pbat/math/SymmetricQuadratureRules.h

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,18 @@
88
namespace pbat {
99
namespace math {
1010

11-
template <class TBasis, class Quad>
12-
Matrix<TBasis::kSize, Quad::kPoints> ReferenceMomentFittingMatrix(TBasis const& Pb, Quad const& Q)
13-
{
14-
static_assert(
15-
TBasis::kDims == Quad::kDims,
16-
"Dimensions of the quadrature rule and the polynomial basis must match, i.e. a 2D "
17-
"polynomial must be fit in a 2D integration domain");
18-
Matrix<TBasis::kSize, Quad::kPoints> P{};
19-
Eigen::Map<Matrix<Quad::kDims + 1, Quad::kPoints> const> Xg(Q.points.data());
20-
for (auto g = 0u; g < Quad::kPoints; ++g)
21-
P.col(g) = Pb.eval(Xg.col(g).template segment<Quad::kDims>(1));
22-
return P;
23-
}
24-
2511
/**
2612
* @brief Represents a quadrature scheme that can be constructed via existing quadrature schemes.
2713
* However, this generic quadrature scheme can be modified, i.e. its points and weights are instance
2814
* member variables.
2915
*/
3016
template <class Quad>
31-
struct ModifiableQuadratureScheme
17+
struct FixedSizeVariableQuadrature
3218
{
3319
inline static std::uint8_t constexpr kDims = Quad::kDims;
3420
inline static std::uint16_t constexpr kPoints = Quad::kPoints;
3521
inline static std::uint8_t constexpr kOrder = Quad::kOrder;
36-
ModifiableQuadratureScheme() : points(Quad::points), weights(Quad::weights) {}
22+
FixedSizeVariableQuadrature() : points(Quad::points), weights(Quad::weights) {}
3723

3824
decltype(Quad::points) points;
3925
decltype(Quad::weights) weights;

0 commit comments

Comments
 (0)