-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathembedding.py
More file actions
68 lines (60 loc) · 2.86 KB
/
embedding.py
File metadata and controls
68 lines (60 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
"""Utilities for embedding finite elements into DG spaces for checkpointing (scalable HDF5 I/O)."""
import finat.ufl
import ufl
def get_embedding_dg_element(element, value_shape, broken_cg=False):
    """Return a discontinuous element into which ``element`` can be embedded.

    The scalar building block is chosen from the cell of ``element``:

    * on a ``ufl.TensorProductCell`` whose ``element.degree()`` is an ``int``,
      a "DQ" element of that degree;
    * on a ``ufl.TensorProductCell`` with per-factor degrees, a
      ``TensorProductElement`` of one "DG"/"DQ" element per sub-cell;
    * on any other cell, a "DG"/"DQ" element of degree
      ``element.embedded_superdegree``.

    :arg element: the element to embed. Its cells must all be the same cell;
        otherwise unpacking ``set(element.cell.cells)`` raises ``ValueError``.
    :arg value_shape: value shape of the result. Rank 0 returns the scalar
        element, rank 1 a ``VectorElement`` of that length, and higher ranks
        a ``TensorElement`` (reusing ``element.symmetry()`` when ``element``
        is itself a ``TensorElement``, otherwise no symmetry).
    :kwarg broken_cg: if true, the scalar element is rebuilt with family
        "Lagrange" and wrapped in a ``BrokenElement``.
    :returns: the embedding element.
    """
    (cell,) = set(element.cell.cells)

    def dg_family(c):
        # "DG" on simplices, "DQ" otherwise.
        # NOTE(review): ``is_simplex`` is read as an attribute, not called.
        # If the installed UFL exposes it as a method, this is always truthy
        # and "DG" is always chosen -- confirm against the UFL version in use.
        return "DG" if c.is_simplex else "DQ"

    if isinstance(cell, ufl.TensorProductCell):
        degree = element.degree()
        if type(degree) is int:
            # One degree for the whole tensor-product cell.
            scalar = finat.ufl.FiniteElement("DQ", cell=cell, degree=degree)
        else:
            # One discontinuous factor per sub-cell, each with its own degree.
            factors = [
                finat.ufl.FiniteElement(dg_family(sub_cell), cell=sub_cell, degree=sub_degree)
                for sub_cell, sub_degree in zip(cell.sub_cells, degree)
            ]
            scalar = finat.ufl.TensorProductElement(*factors)
    else:
        degree = element.embedded_superdegree
        scalar = finat.ufl.FiniteElement(dg_family(cell), cell=cell, degree=degree)

    if broken_cg:
        scalar = finat.ufl.BrokenElement(scalar.reconstruct(family="Lagrange"))

    rank = len(value_shape)
    if rank == 0:
        return scalar
    if rank == 1:
        (dim,) = value_shape
        return finat.ufl.VectorElement(scalar, dim=dim)
    if isinstance(element, finat.ufl.TensorElement):
        symmetry = element.symmetry()
    else:
        symmetry = None
    return finat.ufl.TensorElement(scalar, shape=value_shape, symmetry=symmetry)
# Element families (as returned by ``element.family()``) that are passed
# through unchanged by ``get_embedding_element_for_checkpointing``.
native_elements_for_checkpointing = {"Lagrange", "Discontinuous Lagrange", "Q", "DQ", "Real"}


def get_embedding_element_for_checkpointing(element, value_shape):
    """Convert the given UFL element to an element that :class:`~.CheckpointFile` can handle.

    :arg element: the element to convert.
    :arg value_shape: value shape forwarded to ``get_embedding_dg_element``
        when an embedding is needed.
    :returns: ``element`` itself if its family is in
        ``native_elements_for_checkpointing``; otherwise the DG embedding
        element built by ``get_embedding_dg_element``.
    """
    if element.family() not in native_elements_for_checkpointing:
        return get_embedding_dg_element(element, value_shape)
    return element
def get_embedding_method_for_checkpointing(element):
    """Return the method used to embed ``element`` in a DG space.

    :arg element: the element being embedded.
    :returns: ``"interpolate"`` or ``"project"``.

    Decision order:

    1. ``HDivElement``, ``HCurlElement`` and ``WithMapping`` wrappers always
       use ``"project"``.
    2. ``VectorElement``/``TensorElement`` use the method of their single
       distinct sub-element (the unpacking raises ``ValueError`` if the
       sub-elements are not all equal).
    3. Elements whose family is in a fixed list use ``"interpolate"``.
    4. ``TensorProductElement`` uses ``"project"`` if any factor needs it,
       otherwise ``"interpolate"``.
    5. Anything else uses ``"project"``.
    """
    if isinstance(element, (finat.ufl.HDivElement, finat.ufl.HCurlElement, finat.ufl.WithMapping)):
        return "project"

    if isinstance(element, (finat.ufl.VectorElement, finat.ufl.TensorElement)):
        (sub_element,) = set(element.sub_elements)
        return get_embedding_method_for_checkpointing(sub_element)

    interpolatable_families = (
        'Lagrange', 'Discontinuous Lagrange',
        'Nedelec 1st kind H(curl)', 'Raviart-Thomas',
        'Nedelec 2nd kind H(curl)', 'Brezzi-Douglas-Marini',
        'Q', 'DQ',
        'S', 'DPC', 'Real',
    )
    if element.family() in interpolatable_families:
        return "interpolate"

    if isinstance(element, finat.ufl.TensorProductElement):
        # Evaluate every factor first, then let a single "project" win.
        factor_methods = [get_embedding_method_for_checkpointing(factor)
                          for factor in element.factor_elements]
        return "project" if "project" in factor_methods else "interpolate"

    return "project"