Skip to content

Commit a28c83a

Browse files
committed
Remove repeated points from dual_basis
1 parent f6b2712 commit a28c83a

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

FIAT/check_format_variant.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def parse_quadrature_scheme(ref_el, degree, quad_scheme=None):
103103
if opt in supported_splits:
104104
splitting = supported_splits[opt]
105105
ref_el = splitting(ref_el)
106+
elif opt.startswith("KMV") and opt != "KMV":
107+
match = re.match(r"^KMV(?:\((\d+)\))?$", opt)
108+
degree, = match.groups()
109+
degree = int(degree)
110+
scheme = "KMV"
106111
else:
107112
scheme = opt
108113
return create_quadrature(ref_el, degree, scheme or "default")

FIAT/macro.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,24 @@ def __init__(self, ref_el, Q_ref, parent_facets=None):
406406
Q_cur = FacetQuadratureRule(ref_el, parent_dim, entity, Q_ref)
407407
pts.extend(Q_cur.pts)
408408
wts.extend(Q_cur.wts)
409+
410+
# Merge points on facets, if any
411+
rtol = 1E-12
412+
bary = ref_el.compute_barycentric_coordinates(pts)
413+
if numpy.isclose(bary, 0, rtol=rtol).any():
414+
iorder = numpy.lexsort(bary.T)
415+
iprev = iorder[0]
416+
unique_pts = [pts[iprev]]
417+
unique_wts = [wts[iprev]]
418+
for icur in iorder[1:]:
419+
if numpy.allclose(bary[icur], bary[iprev], rtol=rtol):
420+
unique_wts[-1] += wts[icur]
421+
else:
422+
unique_pts.append(pts[icur])
423+
unique_wts.append(wts[icur])
424+
iprev = icur
425+
pts = unique_pts
426+
wts = unique_wts
409427
pts = tuple(pts)
410428
wts = tuple(wts)
411429
super().__init__(ref_el, pts, wts)

FIAT/quadrature_schemes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535
import numpy
3636

37-
from FIAT.quadrature import (QuadratureRule, FacetQuadratureRule, make_quadrature,
37+
from FIAT.quadrature import (FacetQuadratureRule,
38+
GaussLobattoLegendreQuadratureLineRule,
39+
QuadratureRule, make_quadrature,
3840
make_tensor_product_quadrature, map_quadrature)
3941
from FIAT.reference_element import (HEXAHEDRON, QUADRILATERAL, TENSORPRODUCT,
4042
TETRAHEDRON, TRIANGLE, symmetric_simplex, ufc_simplex)
@@ -118,7 +120,10 @@ def _kmv_lump_scheme(ref_el, degree):
118120
"""Specialized quadrature schemes for P < 6 for KMV simplical elements."""
119121

120122
sd = ref_el.get_spatial_dimension()
121-
if sd not in {2, 3}:
123+
if sd == 1:
124+
num_points = degree + 1
125+
return GaussLobattoLegendreQuadratureLineRule(ref_el, num_points)
126+
elif sd > 3:
122127
raise ValueError("Dimension not supported")
123128

124129
T = ufc_simplex(sd)

finat/fiat_elements.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,18 @@ def _dual_basis(self):
213213
Qdense = np.zeros(Qshape, dtype=np.float64)
214214
for idx, value in Q.items():
215215
Qdense[idx] = value
216+
# Compress repeated points
217+
repeated_pts = tuple(map(tuple, np.round(allpts, decimals=12)))
218+
unique_pts = list(dict.fromkeys(repeated_pts))
219+
if len(unique_pts) < len(repeated_pts):
220+
Qrepeated = Qdense
221+
Qshape = (Qshape[0], len(unique_pts), *Qshape[2:])
222+
Qdense = np.zeros(Qshape, dtype=np.float64)
223+
for j, i in enumerate(map(unique_pts.index, repeated_pts)):
224+
Qdense[:, i, ...] += Qrepeated[:, j, ...]
225+
allpts = unique_pts
216226
Q = gem.Literal(Qdense)
227+
217228
return Q, np.asarray(allpts)
218229

219230
@property

0 commit comments

Comments
 (0)