Skip to content

Commit f3211e7

Browse files
authored
Merge pull request #201 from firedrakeproject/JHopeCollins/mixed-broken-element
2 parents 9968afb + 230b5e2 commit f3211e7

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

finat/ufl/brokenelement.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,30 @@
1010
# Modified by Matthew Scroggs, 2023
1111

1212
from finat.ufl.finiteelementbase import FiniteElementBase
13+
from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement
1314
from ufl.sobolevspace import L2
1415

1516

1617
class BrokenElement(FiniteElementBase):
1718
"""The discontinuous version of an existing Finite Element space."""
19+
def __new__(cls, element):
20+
"""
21+
Broken qualifier must be below Mixed/Vector/Tensor so we
22+
overload __new__ to return:
23+
24+
BrokenElement(MixedElement(elem0, elem1)) -> MixedElement(BrokenElement(elem0), BrokenElement(elem1))
25+
26+
and similarly for VectorElement and TensorElement.
27+
"""
28+
if isinstance(element, (VectorElement, TensorElement)):
29+
return element.reconstruct(sub_element=BrokenElement(element.sub_elements[0]))
30+
31+
elif isinstance(element, MixedElement):
32+
return MixedElement(list(map(BrokenElement, element.sub_elements)))
33+
34+
else: # hopefully no special casing needed
35+
return super().__new__(cls)
36+
1837
def __init__(self, element):
1938
"""Init."""
2039
self._element = element

finat/ufl/mixedelement.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,10 @@ def __repr__(self):
345345
"""Doc."""
346346
return self._repr
347347

348-
def reconstruct(self, **kwargs):
348+
def reconstruct(self, sub_element=None, **kwargs):
349349
"""Doc."""
350-
sub_element = self._sub_element.reconstruct(**kwargs)
350+
if sub_element is None:
351+
sub_element = self._sub_element.reconstruct(**kwargs)
351352
return VectorElement(sub_element, dim=len(self.sub_elements))
352353

353354
def variant(self):
@@ -544,9 +545,10 @@ def symmetry(self):
544545
"""
545546
return self._symmetry
546547

547-
def reconstruct(self, **kwargs):
548+
def reconstruct(self, sub_element=None, **kwargs):
548549
"""Doc."""
549-
sub_element = self._sub_element.reconstruct(**kwargs)
550+
if sub_element is None:
551+
sub_element = self._sub_element.reconstruct(**kwargs)
550552
return TensorElement(sub_element, shape=self._shape, symmetry=self._symmetry)
551553

552554
def __str__(self):
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
import ufl
3+
from finat.ufl import FiniteElement, BrokenElement, VectorElement, TensorElement, MixedElement
4+
5+
sub_elements = [
6+
FiniteElement("CG", ufl.triangle, 1),
7+
FiniteElement("BDM", ufl.triangle, 2),
8+
FiniteElement("DG", ufl.interval, 2, variant="spectral")
9+
]
10+
11+
sub_ids = [
12+
"CG(1)",
13+
"BDM(2)",
14+
"DG(2,spectral)"
15+
]
16+
17+
18+
@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids)
19+
@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)"))
20+
def test_create_broken_vector_or_tensor_element(shape, sub_element):
21+
"""Check that BrokenElement returns a nested element
22+
for mixed, vector, and tensor elements.
23+
"""
24+
if not isinstance(shape, int):
25+
make_element = lambda elem: TensorElement(elem, shape=shape)
26+
else:
27+
make_element = lambda elem: VectorElement(elem, dim=shape)
28+
29+
tensor = make_element(sub_element)
30+
expected = make_element(BrokenElement(sub_element))
31+
32+
assert BrokenElement(tensor) == expected
33+
34+
35+
@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]],
36+
ids=(f"nsubs={len(sub_elements)}", "nsubs=1"))
37+
def test_create_broken_mixed_element(sub_elements):
38+
"""Check that BrokenElement returns a nested element
39+
for mixed, vector, and tensor elements.
40+
"""
41+
mixed = MixedElement(sub_elements)
42+
expected = MixedElement([BrokenElement(elem) for elem in sub_elements])
43+
assert BrokenElement(mixed) == expected
44+
45+
46+
if __name__ == "__main__":
47+
import os
48+
import sys
49+
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])

0 commit comments

Comments
 (0)