Skip to content

Commit 8499723

Browse files
committed
Fix bugs in moment fitting when assembling multiple systems
1 parent 0d2b984 commit 8499723

File tree

6 files changed

+156
-93
lines changed

6 files changed

+156
-93
lines changed

bindings/pypbat/math/MomentFitting.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,10 @@ void BindMomentFitting(pybind11::module& m)
8989

9090
m.def(
9191
"block_diagonalize_moment_fitting",
92-
[](Eigen::Ref<MatrixX const> const& M,
93-
Eigen::Ref<MatrixX const> const& B,
94-
Eigen::Ref<IndexVectorX const> const& P) {
95-
return pbat::math::BlockDiagonalReferenceMomentFittingSystem(M, B, P);
92+
[](Eigen::Ref<MatrixX const> const& M, Eigen::Ref<IndexVectorX const> const& P) {
93+
return pbat::math::BlockDiagonalReferenceMomentFittingSystem(M, P);
9694
},
9795
pyb::arg("M"),
98-
pyb::arg("B"),
9996
pyb::arg("P"),
10097
"Assemble the block diagonal row sparse matrix GM, such that GM @ w = B.flatten(order='F') "
10198
"contains all the reference moment fitting systems in (M,B,P).");

python/examples/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ libigl==v2.5.1
55
polyscope==2.2.1
66
ilupp==1.0.2
77
ipctk==1.2.0
8-
networkx==3.3
8+
networkx==3.3
9+
qpsolvers[open_source_solvers]==4.4.0

python/polynomial/basis.py

Lines changed: 75 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def dot(u, v, x: list, a: list, b: list):
10-
assert(len(a) == len(b) == len(x))
10+
assert (len(a) == len(b) == len(x))
1111
I = u*v
1212
for i in reversed(range(len(a))):
1313
I = sp.integrate(I, (x[i], a[i], b[i])).expand()
@@ -44,19 +44,22 @@ def orthonormalized_basis(dims, order, x, a, b):
4444
V = []
4545
if dims == 1:
4646
# Generate the monomial basis
47-
Vl = sorted(itermonomials(x, order), key=monomial_key('grlex', list(reversed(x))))
47+
Vl = sorted(itermonomials(x, order),
48+
key=monomial_key('grlex', list(reversed(x))))
4849
# Orthonormalize the line's polynomial basis
4950
Vlo = gram_schmidt(Vl, x, a, b)
5051
V = normalize(Vlo, x, a, b)
5152
if dims == 2:
5253
# Generate the monomial basis
53-
Vf = sorted(itermonomials(x, order), key=monomial_key('grlex', list(reversed(x))))
54+
Vf = sorted(itermonomials(x, order),
55+
key=monomial_key('grlex', list(reversed(x))))
5456
# Orthonormalize the triangle's polynomial basis'
5557
Vfo = gram_schmidt(Vf, x, a, b)
5658
V = normalize(Vfo, x, a, b)
5759
if dims == 3:
5860
# Generate the monomial basis
59-
Vt = sorted(itermonomials(x, order), key=monomial_key('grlex', list(reversed(x))))
61+
Vt = sorted(itermonomials(x, order),
62+
key=monomial_key('grlex', list(reversed(x))))
6063
# Orthonormalize the tetrahedron's polynomial basis'
6164
Vto = gram_schmidt(Vt, x, a, b)
6265
V = normalize(Vto, x, a, b)
@@ -88,11 +91,11 @@ def divergence_free_basis(V: list, dims: int, x: list, a: list, b: list):
8891
H = sp.Matrix.zeros(len(V), dims*len(V))
8992
for n in range(len(G)):
9093
gn = G[n]
91-
assert(len(gn) == len(x))
94+
assert (len(gn) == len(x))
9295
div = divergence(gn, x)
9396
# Compute the coefficients of div(gn) w.r.t. orthonormal basis V
9497
for i in range(len(V)):
95-
H[i,n] = dot(div, V[i], x, a, b)
98+
H[i, n] = dot(div, V[i], x, a, b)
9699

97100
HN = H.nullspace()
98101
GG = sp.Matrix(G).transpose()
@@ -102,12 +105,13 @@ def divergence_free_basis(V: list, dims: int, x: list, a: list, b: list):
102105
F.append(sp.simplify(fi))
103106
for fi in F:
104107
div = sp.simplify(divergence(fi, x))
105-
assert(div == 0)
108+
assert (div == 0)
106109
return F
107110

111+
108112
def header(file):
109113
file.write(
110-
"""#ifndef PBAT_MATH_POLYNOMIAL_BASIS_H
114+
"""#ifndef PBAT_MATH_POLYNOMIAL_BASIS_H
111115
#define PBAT_MATH_POLYNOMIAL_BASIS_H
112116
113117
/**
@@ -137,16 +141,17 @@ class DivergenceFreePolynomialBasis;
137141
"""
138142
)
139143

144+
140145
def footer(file):
141146
file.write(
142-
"""
147+
"""
143148
} // namespace math
144149
} // namespace pbat
145150
146151
#endif // PBAT_MATH_POLYNOMIAL_BASIS_H
147152
"""
148153
)
149-
154+
150155

151156
def codegen(file, V: list, X: list, order: int, orthonormal=False):
152157
dimV = len(V)
@@ -157,10 +162,11 @@ def codegen(file, V: list, X: list, order: int, orthonormal=False):
157162
code = cg.codegen(V, lhs=sp.MatrixSymbol("P", len(V), 1))
158163
GV = V.jacobian(X)
159164
jaccode = cg.codegen(GV, lhs=sp.MatrixSymbol("G", *GV.shape))
160-
AV = sp.Matrix([[sp.integrate(V[i], X[d]) for i in range(len(V))] for d in range(len(X))])
165+
AV = sp.Matrix([[sp.integrate(V[i], X[d])
166+
for i in range(len(V))] for d in range(len(X))])
161167
antiderivscode = cg.codegen(AV, lhs=sp.MatrixSymbol("P", *AV.shape))
162168
file.write(
163-
f"""
169+
f"""
164170
template <>
165171
class {classname}<{nvariables}, {order}>
166172
{{
@@ -169,21 +175,21 @@ class {classname}<{nvariables}, {order}>
169175
inline static constexpr std::size_t kOrder = {order};
170176
inline static constexpr std::size_t kSize = {dimV};
171177
172-
[[maybe_unused]] Vector<kSize> eval([[maybe_unused]] Vector<kDims> const& X) const
178+
[[maybe_unused]] Vector<kSize> eval([[maybe_unused]] Vector<kDims> const& X) const
173179
{{
174180
Vector<kSize> P;
175181
{cg.tabulate(code, spaces=8)}
176182
return P;
177-
}}
178-
183+
}}
184+
179185
[[maybe_unused]] Matrix<kDims, kSize> derivatives([[maybe_unused]] Vector<kDims> const& X) const
180186
{{
181187
Matrix<kDims, kSize> Gm;
182188
Scalar* G = Gm.data();
183189
{cg.tabulate(jaccode, spaces=8)}
184190
return Gm;
185191
}}
186-
192+
187193
[[maybe_unused]] Matrix<kSize, kDims> antiderivatives([[maybe_unused]] Vector<kDims> const& X) const
188194
{{
189195
Matrix<kSize, kDims> Pm;
@@ -202,7 +208,7 @@ def div_free_codegen(file, F: list, X: list, order: int):
202208
F = sp.Matrix(F).transpose()
203209
code = cg.codegen(F, lhs=sp.MatrixSymbol("P", *F.shape))
204210
file.write(
205-
f"""
211+
f"""
206212
template <>
207213
class {classname}<{nvariables}, {order}>
208214
{{
@@ -211,7 +217,7 @@ class {classname}<{nvariables}, {order}>
211217
inline static constexpr std::size_t kOrder = {order};
212218
inline static constexpr std::size_t kSize = {dimF};
213219
214-
[[maybe_unused]] Matrix<kSize, kDims> eval([[maybe_unused]] Vector<kDims> const& X) const
220+
[[maybe_unused]] Matrix<kSize, kDims> eval([[maybe_unused]] Vector<kDims> const& X) const
215221
{{
216222
Matrix<kSize, kDims> Pm;
217223
Scalar* P = Pm.data();
@@ -221,6 +227,7 @@ class {classname}<{nvariables}, {order}>
221227
}};
222228
""")
223229

230+
224231
if __name__ == "__main__":
225232
X = sp.Matrix(sp.MatrixSymbol("X", 3, 1))
226233

@@ -230,7 +237,8 @@ class {classname}<{nvariables}, {order}>
230237
)
231238
parser.add_argument("-o", "--order", help="Maximum degree of the polynomial basis functions", type=int,
232239
dest="order", default=1)
233-
parser.add_argument("-d", "--dry-run", help="Prints computed polynomial basis' to stdout", type=bool, dest="dry_run")
240+
parser.add_argument("-d", "--dry-run", help="Prints computed polynomial basis' to stdout",
241+
action="store_true", dest="dry_run")
234242
args = parser.parse_args()
235243

236244
# polynomial order
@@ -261,7 +269,7 @@ class {classname}<{nvariables}, {order}>
261269
for v in Vfo:
262270
print(v)
263271

264-
Vto = orthonormalized_basis(2, order, xt, ta, tb)
272+
Vto = orthonormalized_basis(3, order, xt, ta, tb)
265273
print("Orthonormal Polynomial Basis of order={} on the reference tetrahedron".format(order))
266274
for v in Vto:
267275
print(v)
@@ -270,61 +278,70 @@ class {classname}<{nvariables}, {order}>
270278
header(file)
271279

272280
file.write("\n/**\n * Monomial basis in 1D\n */\n")
273-
for o in range(order):
274-
V = sorted(itermonomials(xl, o+1), key=monomial_key('lex', list(reversed(xl))))
275-
codegen(file, V, xl, o+1)
281+
for o in range(order+2):
282+
V = sorted(itermonomials(xl, o),
283+
key=monomial_key('lex', list(reversed(xl))))
284+
codegen(file, V, xl, o)
276285
file.write("\n/**\n * Monomial basis in 2D\n */\n")
277-
for o in range(order):
278-
V = sorted(itermonomials(xf, o+1), key=monomial_key('lex', list(reversed(xf))))
279-
codegen(file, V, xf, o+1)
286+
for o in range(order+2):
287+
V = sorted(itermonomials(xf, o),
288+
key=monomial_key('lex', list(reversed(xf))))
289+
codegen(file, V, xf, o)
280290
file.write("\n/**\n * Monomial basis in 3D\n */\n")
281-
for o in range(order):
282-
V = sorted(itermonomials(xt, o+1), key=monomial_key('lex', list(reversed(xt))))
283-
codegen(file, V, xt, o+1)
291+
for o in range(order+1):
292+
V = sorted(itermonomials(xt, o),
293+
key=monomial_key('lex', list(reversed(xt))))
294+
codegen(file, V, xt, o)
284295

285296
Vl = []
286297
dVl = []
287-
for o in range(order):
288-
V = orthonormalized_basis(1,o+1, xl, la, lb)
298+
for o in range(order+1):
299+
V = orthonormalized_basis(1, o, xl, la, lb)
289300
Vl.append(V)
290301
dVl.append(divergence_free_basis(V, 1, xl, la, lb))
291302

292303
Vf = []
293304
dVf = []
294-
for o in range(order):
295-
V = orthonormalized_basis(2,o+1, xf, fa, fb)
305+
for o in range(order+1):
306+
V = orthonormalized_basis(2, o, xf, fa, fb)
296307
Vf.append(V)
297308
dVf.append(divergence_free_basis(V, 2, xf, fa, fb))
298309

299310
Vt = []
300311
dVt = []
301-
for o in range(order):
302-
V = orthonormalized_basis(3,o+1, xt, ta, tb)
312+
for o in range(order+1):
313+
V = orthonormalized_basis(3, o, xt, ta, tb)
303314
Vt.append(V)
304315
dVt.append(divergence_free_basis(V, 3, xt, ta, tb))
305316

306-
file.write("\n/**\n * Orthonormalized polynomial basis on reference line\n */\n")
307-
for o in range(order):
308-
codegen(file, Vl[o], xl, o+1, orthonormal=True)
309-
310-
file.write("\n/**\n * Orthonormalized polynomial basis on reference triangle\n */\n")
311-
for o in range(order):
312-
codegen(file, Vf[o], xf, o+1, orthonormal=True)
313-
314-
file.write("\n/**\n * Orthonormalized polynomial basis on reference tetrahedron\n */\n")
315-
for o in range(order):
316-
codegen(file, Vt[o], xt, o+1, orthonormal=True)
317-
318-
file.write("\n/**\n * Divergence free polynomial basis on reference line\n */\n")
319-
for o in range(order):
320-
div_free_codegen(file, dVl[o], xl, o+1)
321-
322-
file.write("\n/**\n * Divergence free polynomial basis on reference triangle\n */\n")
323-
for o in range(order):
324-
div_free_codegen(file, dVf[o], xf, o+1)
325-
326-
file.write("\n/**\n * Divergence free polynomial basis on reference tetrahedron\n */\n")
327-
for o in range(order):
328-
div_free_codegen(file, dVt[o], xt, o+1)
317+
file.write(
318+
"\n/**\n * Orthonormalized polynomial basis on reference line\n */\n")
319+
for o in range(order+1):
320+
codegen(file, Vl[o], xl, o, orthonormal=True)
321+
322+
file.write(
323+
"\n/**\n * Orthonormalized polynomial basis on reference triangle\n */\n")
324+
for o in range(order+1):
325+
codegen(file, Vf[o], xf, o, orthonormal=True)
326+
327+
file.write(
328+
"\n/**\n * Orthonormalized polynomial basis on reference tetrahedron\n */\n")
329+
for o in range(order+1):
330+
codegen(file, Vt[o], xt, o, orthonormal=True)
331+
332+
file.write(
333+
"\n/**\n * Divergence free polynomial basis on reference line\n */\n")
334+
for o in range(order+1):
335+
div_free_codegen(file, dVl[o], xl, o)
336+
337+
file.write(
338+
"\n/**\n * Divergence free polynomial basis on reference triangle\n */\n")
339+
for o in range(order+1):
340+
div_free_codegen(file, dVf[o], xf, o)
341+
342+
file.write(
343+
"\n/**\n * Divergence free polynomial basis on reference tetrahedron\n */\n")
344+
for o in range(order+1):
345+
div_free_codegen(file, dVt[o], xt, o)
329346

330347
footer(file)

source/pbat/math/MomentFitting.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ TEST_CASE("[math] MomentFitting")
9090
S2,
9191
X2.bottomRows(kDims),
9292
w2);
93-
CSRMatrix GM = math::BlockDiagonalReferenceMomentFittingSystem(M, B, P);
94-
CHECK_EQ(GM.rows(), 2*math::OrthonormalPolynomialBasis<kDims, kOrder>::kSize);
93+
CSRMatrix GM = math::BlockDiagonalReferenceMomentFittingSystem(M, P);
94+
CHECK_EQ(GM.rows(), 2 * math::OrthonormalPolynomialBasis<kDims, kOrder>::kSize);
9595
CHECK_EQ(GM.cols(), 8);
9696
CHECK_EQ(GM.nonZeros(), 2 * math::OrthonormalPolynomialBasis<kDims, kOrder>::kSize * 4);
9797
}

0 commit comments

Comments
 (0)