diff --git a/finat/ufl/mixedelement.py b/finat/ufl/mixedelement.py index df45bce9..20c25353 100644 --- a/finat/ufl/mixedelement.py +++ b/finat/ufl/mixedelement.py @@ -13,7 +13,8 @@ import numpy as np -from ufl.cell import as_cell +from ufl.cell import CellSequence, as_cell +from ufl.domain import MeshSequence from finat.ufl.finiteelement import FiniteElement from finat.ufl.finiteelementbase import FiniteElementBase from ufl.permutation import compute_indices @@ -39,18 +40,6 @@ def __init__(self, *elements, **kwargs): elements = [MixedElement(e) if isinstance(e, (tuple, list)) else e for e in elements] self._sub_elements = elements - - # Pick the first cell, for now all should be equal - cells = tuple(sorted(set(element.cell for element in elements) - set([None]))) - self._cells = cells - if cells: - cell = cells[0] - # Require that all elements are defined on the same cell - if not all(c == cell for c in cells[1:]): - raise ValueError("Sub elements must live on the same cell.") - else: - cell = None - # Check that all elements use the same quadrature scheme TODO: # We can allow the scheme not to be defined. if len(elements) == 0: @@ -70,9 +59,16 @@ def __init__(self, *elements, **kwargs): # Initialize element data degrees = {e.degree() for e in self._sub_elements} - {None} degree = max_degree(degrees) if degrees else None - FiniteElementBase.__init__(self, "Mixed", cell, degree, quad_scheme, + FiniteElementBase.__init__(self, "Mixed", self._make_cell(), degree, quad_scheme, reference_value_shape) + def _make_cell(self): + if self.num_sub_elements == 0: + return + else: + cells = tuple(e.cell for e in self.sub_elements) + return CellSequence(cells) + def __repr__(self): """Doc.""" return "MixedElement(" + ", ".join(repr(e) for e in self._sub_elements) + ")" @@ -94,6 +90,8 @@ def symmetry(self, domain): :math:`c_1`. A component is a tuple of one or more ints. """ + if isinstance(domain, MeshSequence): + raise NotImplementedError # Build symmetry map from symmetries of subelements sm = {} # Base index of the current subelement into mixed value @@ -140,6 +138,8 @@ def extract_subelement_component(self, domain, i): component index for a given component index. """ + if isinstance(domain, MeshSequence): + raise NotImplementedError if isinstance(i, int): i = (i,) self._check_component(i) @@ -245,7 +245,16 @@ def embedded_superdegree(self): def reconstruct(self, **kwargs): """Doc.""" - return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements]) + cell = kwargs.pop('cell', None) + if cell is None: + cell = self.cell + else: + if not isinstance(cell, CellSequence): + # Allow for passing a single base cell. + cell = CellSequence([cell] * len(self.sub_elements)) + return type(self)( + *[e.reconstruct(cell=c, **kwargs) for c, e in zip(cell.cells, self.sub_elements)], + ) def variant(self): """Doc.""" @@ -307,8 +316,10 @@ def __init__(self, family, cell=None, degree=None, dim=None, reference_value_shape = (dim,) + sub_element.reference_value_shape # Initialize element data - MixedElement.__init__(self, sub_elements, - reference_value_shape=reference_value_shape) + MixedElement.__init__( + self, sub_elements, + reference_value_shape=reference_value_shape, + ) FiniteElementBase.__init__(self, sub_element.family(), sub_element.cell, sub_element.degree(), sub_element.quadrature_scheme(), reference_value_shape) @@ -323,6 +334,13 @@ def __init__(self, family, cell=None, degree=None, dim=None, # Cache repr string self._repr = f"VectorElement({repr(sub_element)}, dim={dim}{var_str})" + def _make_cell(self): + if self.num_sub_elements == 0: + return + else: + cell, = set(e.cell for e in self.sub_elements) + return cell + def __repr__(self): """Doc.""" return self._repr @@ -435,8 +453,10 @@ def __init__(self, family, cell=None, degree=None, shape=None, reference_value_shape = reference_value_shape + sub_element.reference_value_shape # Initialize element data - MixedElement.__init__(self, sub_elements, - reference_value_shape=reference_value_shape) + MixedElement.__init__( + self, sub_elements, + reference_value_shape=reference_value_shape, + ) self._family = sub_element.family() self._degree = sub_element.degree() self._sub_element = sub_element @@ -454,6 +474,13 @@ def __init__(self, family, cell=None, degree=None, shape=None, self._repr = (f"TensorElement({repr(sub_element)}, shape={shape}, " f"symmetry={symmetry}{var_str})") + def _make_cell(self): + if self.num_sub_elements == 0: + return + else: + cell, = set(e.cell for e in self.sub_elements) + return cell + @property def pullback(self): """Get pull back."""