Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 46 additions & 19 deletions finat/ufl/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import numpy as np

from ufl.cell import as_cell
from ufl.cell import CellSequence, as_cell
from ufl.domain import MeshSequence
from finat.ufl.finiteelement import FiniteElement
from finat.ufl.finiteelementbase import FiniteElementBase
from ufl.permutation import compute_indices
Expand All @@ -39,18 +40,6 @@ def __init__(self, *elements, **kwargs):
elements = [MixedElement(e) if isinstance(e, (tuple, list)) else e
for e in elements]
self._sub_elements = elements

# Pick the first cell, for now all should be equal
cells = tuple(sorted(set(element.cell for element in elements) - set([None])))
self._cells = cells
if cells:
cell = cells[0]
# Require that all elements are defined on the same cell
if not all(c == cell for c in cells[1:]):
raise ValueError("Sub elements must live on the same cell.")
else:
cell = None

# Check that all elements use the same quadrature scheme TODO:
# We can allow the scheme not to be defined.
if len(elements) == 0:
Expand All @@ -70,9 +59,16 @@ def __init__(self, *elements, **kwargs):
# Initialize element data
degrees = {e.degree() for e in self._sub_elements} - {None}
degree = max_degree(degrees) if degrees else None
FiniteElementBase.__init__(self, "Mixed", cell, degree, quad_scheme,
FiniteElementBase.__init__(self, "Mixed", self._make_cell(), degree, quad_scheme,
reference_value_shape)

def _make_cell(self):
if self.num_sub_elements == 0:
return
else:
cells = tuple(e.cell for e in self.sub_elements)
return CellSequence(cells)

def __repr__(self):
"""Doc."""
return "MixedElement(" + ", ".join(repr(e) for e in self._sub_elements) + ")"
Expand All @@ -94,6 +90,8 @@ def symmetry(self, domain):
:math:`c_1`.
A component is a tuple of one or more ints.
"""
if isinstance(domain, MeshSequence):
raise NotImplementedError
# Build symmetry map from symmetries of subelements
sm = {}
# Base index of the current subelement into mixed value
Expand Down Expand Up @@ -140,6 +138,8 @@ def extract_subelement_component(self, domain, i):

component index for a given component index.
"""
if isinstance(domain, MeshSequence):
raise NotImplementedError
if isinstance(i, int):
i = (i,)
self._check_component(i)
Expand Down Expand Up @@ -245,7 +245,16 @@ def embedded_superdegree(self):

def reconstruct(self, **kwargs):
"""Doc."""
return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements])
cell = kwargs.pop('cell', None)
if cell is None:
cell = self.cell
else:
if not isinstance(cell, CellSequence):
# Allow for passing a single base cell.
cell = CellSequence([cell] * len(self.sub_elements))
return type(self)(
*[e.reconstruct(cell=c, **kwargs) for c, e in zip(cell.cells, self.sub_elements)],
)

def variant(self):
"""Doc."""
Expand Down Expand Up @@ -307,8 +316,10 @@ def __init__(self, family, cell=None, degree=None, dim=None,
reference_value_shape = (dim,) + sub_element.reference_value_shape

# Initialize element data
MixedElement.__init__(self, sub_elements,
reference_value_shape=reference_value_shape)
MixedElement.__init__(
self, sub_elements,
reference_value_shape=reference_value_shape,
)

FiniteElementBase.__init__(self, sub_element.family(), sub_element.cell, sub_element.degree(),
sub_element.quadrature_scheme(), reference_value_shape)
Expand All @@ -323,6 +334,13 @@ def __init__(self, family, cell=None, degree=None, dim=None,
# Cache repr string
self._repr = f"VectorElement({repr(sub_element)}, dim={dim}{var_str})"

def _make_cell(self):
if self.num_sub_elements == 0:
return
else:
cell, = set(e.cell for e in self.sub_elements)
return cell

def __repr__(self):
"""Doc."""
return self._repr
Expand Down Expand Up @@ -435,8 +453,10 @@ def __init__(self, family, cell=None, degree=None, shape=None,

reference_value_shape = reference_value_shape + sub_element.reference_value_shape
# Initialize element data
MixedElement.__init__(self, sub_elements,
reference_value_shape=reference_value_shape)
MixedElement.__init__(
self, sub_elements,
reference_value_shape=reference_value_shape,
)
self._family = sub_element.family()
self._degree = sub_element.degree()
self._sub_element = sub_element
Expand All @@ -454,6 +474,13 @@ def __init__(self, family, cell=None, degree=None, shape=None,
self._repr = (f"TensorElement({repr(sub_element)}, shape={shape}, "
f"symmetry={symmetry}{var_str})")

def _make_cell(self):
if self.num_sub_elements == 0:
return
else:
cell, = set(e.cell for e in self.sub_elements)
return cell

@property
def pullback(self):
"""Get pull back."""
Expand Down