Skip to content

Commit 6c564ac

Browse files
committed
introduce MixedMesh
1 parent 69cce58 commit 6c564ac

21 files changed

+546
-191
lines changed

firedrake/assemble.py

Lines changed: 114 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ def parloop_builders(self):
10201020
self._bcs,
10211021
local_kernel,
10221022
subdomain_id,
1023-
self.all_integer_subdomain_ids,
1023+
self.all_integer_subdomain_ids[local_kernel.kinfo.domain_number],
10241024
diagonal=self.diagonal,
10251025
)
10261026
)
@@ -1037,15 +1037,15 @@ def local_kernels(self):
10371037
each possible combination.
10381038
10391039
"""
1040-
try:
1041-
topology, = set(d.topology for d in self._form.ufl_domains())
1042-
except ValueError:
1043-
raise NotImplementedError("All integration domains must share a mesh topology")
1040+
#try:
1041+
# topology, = set(d.topology for d in self._form.ufl_domains())
1042+
#except ValueError:
1043+
# raise NotImplementedError("All integration domains must share a mesh topology")
10441044

1045-
for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046-
domain = extract_unique_domain(o)
1047-
if domain is not None and domain.topology != topology:
1048-
raise NotImplementedError("Assembly with multiple meshes is not supported")
1045+
#for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046+
# domain = extract_unique_domain(o)
1047+
# if domain is not None and domain.topology != topology:
1048+
# raise NotImplementedError("Assembly with multiple meshes is not supported")
10491049

10501050
if isinstance(self._form, ufl.Form):
10511051
kernels = tsfc_interface.compile_form(
@@ -1345,7 +1345,7 @@ def _make_maps_and_regions(self):
13451345
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, allocation_integral_types)
13461346
else:
13471347
maps_and_regions = defaultdict(lambda: defaultdict(set))
1348-
for local_kernel in self._all_local_kernels:
1348+
for local_kernel in self._all_local_kernels:
13491349
i, j = local_kernel.indices
13501350
# Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
13511351
integral_type = local_kernel.kinfo.integral_type
@@ -1565,8 +1565,13 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
15651565
self._diagonal = diagonal
15661566
self._unroll = unroll
15671567

1568+
self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo)
1569+
self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo)
1570+
self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo)
15681571
self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
15691572
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
1573+
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
1574+
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
15701575

15711576
self._map_arg_cache = {}
15721577
# Cache for holding :class:`op2.MapKernelArg` instances.
@@ -1580,8 +1585,13 @@ def build(self):
15801585
for arg in self._kinfo.arguments]
15811586

15821587
# we should use up all of the coefficients and constants
1588+
assert_empty(self._active_coordinates)
1589+
assert_empty(self._active_cell_orientations)
1590+
assert_empty(self._active_cell_sizes)
15831591
assert_empty(self._active_coefficients)
15841592
assert_empty(self._constants)
1593+
assert_empty(self._active_exterior_facets)
1594+
assert_empty(self._active_interior_facets)
15851595

15861596
iteration_regions = {"exterior_facet_top": op2.ON_TOP,
15871597
"exterior_facet_bottom": op2.ON_BOTTOM,
@@ -1606,7 +1616,8 @@ def _integral_type(self):
16061616

16071617
@cached_property
16081618
def _mesh(self):
1609-
return self._form.ufl_domains()[self._kinfo.domain_number]
1619+
all_meshes = extract_domains(self._form)
1620+
return all_meshes[self._kinfo.domain_number]
16101621

16111622
@cached_property
16121623
def _needs_subset(self):
@@ -1711,7 +1722,22 @@ def _as_global_kernel_arg_output(_, self):
17111722

17121723
@_as_global_kernel_arg.register(kernel_args.CoordinatesKernelArg)
17131724
def _as_global_kernel_arg_coordinates(_, self):
1714-
V = self._mesh.coordinates.function_space()
1725+
coord = next(self._active_coordinates)
1726+
V = coord.function_space()
1727+
return self._make_dat_global_kernel_arg(V)
1728+
1729+
1730+
@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg)
1731+
def _as_global_kernel_arg_cell_orientations(_, self):
1732+
c = next(self._active_cell_orientations)
1733+
V = c.function_space()
1734+
return self._make_dat_global_kernel_arg(V)
1735+
1736+
1737+
@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg)
1738+
def _as_global_kernel_arg_cell_sizes(_, self):
1739+
c = next(self._active_cell_sizes)
1740+
V = c.function_space()
17151741
return self._make_dat_global_kernel_arg(V)
17161742

17171743

@@ -1739,19 +1765,15 @@ def _as_global_kernel_arg_constant(_, self):
17391765
return op2.GlobalKernelArg((value_size,))
17401766

17411767

1742-
@_as_global_kernel_arg.register(kernel_args.CellSizesKernelArg)
1743-
def _as_global_kernel_arg_cell_sizes(_, self):
1744-
V = self._mesh.cell_sizes.function_space()
1745-
return self._make_dat_global_kernel_arg(V)
1746-
1747-
17481768
@_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg)
17491769
def _as_global_kernel_arg_exterior_facet(_, self):
1770+
_ = next(self._active_exterior_facets)
17501771
return op2.DatKernelArg((1,))
17511772

17521773

17531774
@_as_global_kernel_arg.register(kernel_args.InteriorFacetKernelArg)
17541775
def _as_global_kernel_arg_interior_facet(_, self):
1776+
_ = next(self._active_interior_facets)
17551777
return op2.DatKernelArg((2,))
17561778

17571779

@@ -1764,12 +1786,6 @@ def _as_global_kernel_arg_cell_facet(_, self):
17641786
return op2.DatKernelArg((num_facets, 2))
17651787

17661788

1767-
@_as_global_kernel_arg.register(kernel_args.CellOrientationsKernelArg)
1768-
def _as_global_kernel_arg_cell_orientations(_, self):
1769-
V = self._mesh.cell_orientations().function_space()
1770-
return self._make_dat_global_kernel_arg(V)
1771-
1772-
17731789
@_as_global_kernel_arg.register(LayerCountKernelArg)
17741790
def _as_global_kernel_arg_layer_count(_, self):
17751791
return op2.GlobalKernelArg((1,))
@@ -1803,8 +1819,13 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
18031819
self._diagonal = diagonal
18041820
self._bcs = bcs
18051821

1822+
self._active_coordinates = _FormHandler.iter_active_coordinates(form, local_knl.kinfo)
1823+
self._active_cell_orientations = _FormHandler.iter_active_cell_orientations(form, local_knl.kinfo)
1824+
self._active_cell_sizes = _FormHandler.iter_active_cell_sizes(form, local_knl.kinfo)
18061825
self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo)
18071826
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
1827+
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
1828+
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
18081829

18091830
def build(self, tensor):
18101831
"""Construct the parloop.
@@ -1931,7 +1952,8 @@ def _indexed_tensor(self):
19311952

19321953
@cached_property
19331954
def _mesh(self):
1934-
return self._form.ufl_domains()[self._kinfo.domain_number]
1955+
all_meshes = extract_domains(self._form)
1956+
return all_meshes[self._kinfo.domain_number]
19351957

19361958
@cached_property
19371959
def _iterset(self):
@@ -2004,8 +2026,22 @@ def _as_parloop_arg_output(_, self):
20042026

20052027
@_as_parloop_arg.register(kernel_args.CoordinatesKernelArg)
20062028
def _as_parloop_arg_coordinates(_, self):
2007-
func = self._mesh.coordinates
2008-
map_ = self._get_map(func.function_space())
2029+
func = next(self._active_coordinates)
2030+
map_ = self._get_map(func.function_space()) #Compose!
2031+
return op2.DatParloopArg(func.dat, map_)
2032+
2033+
2034+
@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg)
2035+
def _as_parloop_arg_cell_orientations(_, self):
2036+
func = next(self._active_cell_orientations)
2037+
map_ = self._get_map(func.function_space()) #Compose!
2038+
return op2.DatParloopArg(func.dat, map_)
2039+
2040+
2041+
@_as_parloop_arg.register(kernel_args.CellSizesKernelArg)
2042+
def _as_parloop_arg_cell_sizes(_, self):
2043+
func = next(self._active_cell_sizes)
2044+
map_ = self._get_map(func.function_space()) #Compose!
20092045
return op2.DatParloopArg(func.dat, map_)
20102046

20112047

@@ -2025,28 +2061,26 @@ def _as_parloop_arg_constant(arg, self):
20252061
return op2.GlobalParloopArg(const.dat)
20262062

20272063

2028-
@_as_parloop_arg.register(kernel_args.CellOrientationsKernelArg)
2029-
def _as_parloop_arg_cell_orientations(_, self):
2030-
func = self._mesh.cell_orientations()
2031-
m = self._get_map(func.function_space())
2032-
return op2.DatParloopArg(func.dat, m)
2033-
2034-
2035-
@_as_parloop_arg.register(kernel_args.CellSizesKernelArg)
2036-
def _as_parloop_arg_cell_sizes(_, self):
2037-
func = self._mesh.cell_sizes
2038-
m = self._get_map(func.function_space())
2039-
return op2.DatParloopArg(func.dat, m)
2040-
2041-
20422064
@_as_parloop_arg.register(kernel_args.ExteriorFacetKernelArg)
20432065
def _as_parloop_arg_exterior_facet(_, self):
2044-
return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_dat)
2066+
mesh, local_facet_dat = next(self._active_exterior_facets)
2067+
if mesh is self._mesh:
2068+
m = None
2069+
else:
2070+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2071+
assert integral_type == "exterior_facets"
2072+
return op2.DatParloopArg(local_facet_dat, m)
20452073

20462074

20472075
@_as_parloop_arg.register(kernel_args.InteriorFacetKernelArg)
20482076
def _as_parloop_arg_interior_facet(_, self):
2049-
return op2.DatParloopArg(self._mesh.interior_facets.local_facet_dat)
2077+
mesh, local_facet_dat = next(self._active_interior_facets)
2078+
if mesh is self._mesh:
2079+
m = None
2080+
else:
2081+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2082+
assert integral_type == "interior_facets"
2083+
return op2.DatParloopArg(local_facet_dat, m)
20502084

20512085

20522086
@_as_parloop_arg.register(CellFacetKernelArg)
@@ -2068,6 +2102,27 @@ def _as_parloop_arg_layer_count(_, self):
20682102
class _FormHandler:
20692103
"""Utility class for inspecting forms and local kernels."""
20702104

2105+
@staticmethod
2106+
def iter_active_coordinates(form, kinfo):
2107+
"""Yield the form coordinates referenced in ``kinfo``."""
2108+
all_meshes = extract_domains(form)
2109+
for i in kinfo.active_domain_numbers.coordinates:
2110+
yield all_meshes[i].coordinates
2111+
2112+
@staticmethod
2113+
def iter_active_cell_orientations(form, kinfo):
2114+
"""Yield the form cell orientations referenced in ``kinfo``."""
2115+
all_meshes = extract_domains(form)
2116+
for i in kinfo.active_domain_numbers.cell_orientations:
2117+
yield all_meshes[i].cell_orientations()
2118+
2119+
@staticmethod
2120+
def iter_active_cell_sizes(form, kinfo):
2121+
"""Yield the form cell sizes referenced in ``kinfo``."""
2122+
all_meshes = extract_domains(form)
2123+
for i in kinfo.active_domain_numbers.cell_sizes:
2124+
yield all_meshes[i].cell_sizes
2125+
20712126
@staticmethod
20722127
def iter_active_coefficients(form, kinfo):
20732128
"""Yield the form coefficients referenced in ``kinfo``."""
@@ -2087,6 +2142,22 @@ def iter_constants(form, kinfo):
20872142
for constant_index in kinfo.constant_numbers:
20882143
yield all_constants[constant_index]
20892144

2145+
@staticmethod
2146+
def iter_active_exterior_facets(form, kinfo):
2147+
"""Yield the form exterior facets referenced in ``kinfo``."""
2148+
all_meshes = extract_domains(form)
2149+
for i in kinfo.active_domain_numbers.exterior_facets:
2150+
mesh = all_meshes[i]
2151+
yield mesh, mesh.exterior_facets.local_facet_dat
2152+
2153+
@staticmethod
2154+
def iter_active_interior_facets(form, kinfo):
2155+
"""Yield the form interior facets referenced in ``kinfo``."""
2156+
all_meshes = extract_domains(form)
2157+
for i in kinfo.active_domain_numbers.interior_facets:
2158+
mesh = all_meshes[i]
2159+
yield mesh, mesh.interior_facets.local_facet_dat
2160+
20902161
@staticmethod
20912162
def index_function_spaces(form, indices):
20922163
"""Return the function spaces of the form's arguments, indexed

firedrake/checkpointing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,8 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
565565
:kwarg distribution_name: the name under which distribution is saved; if `None`, auto-generated name will be used.
566566
:kwarg permutation_name: the name under which permutation is saved; if `None`, auto-generated name will be used.
567567
"""
568+
# TODO: Add general MixedMesh support.
569+
mesh = mesh.unique()
568570
mesh.init()
569571
# Handle extruded mesh
570572
tmesh = mesh.topology
@@ -827,6 +829,8 @@ def get_timestepping_history(self, mesh, name):
827829
@PETSc.Log.EventDecorator("SaveFunctionSpace")
828830
def _save_function_space(self, V):
829831
mesh = V.mesh()
832+
# TODO: Add general MixedMesh support.
833+
mesh = mesh.unique()
830834
if isinstance(V.topological, impl.MixedFunctionSpace):
831835
V_name = self._generate_function_space_name(V)
832836
base_path = self._path_to_mixed_function_space(mesh.name, V_name)
@@ -902,10 +906,12 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
902906
each index.
903907
"""
904908
V = f.function_space()
905-
mesh = V.mesh()
906909
if name:
907910
g = Function(V, val=f.dat, name=name)
908911
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
912+
mesh = V.mesh()
913+
# TODO: Add general MixedMesh support.
914+
mesh = mesh.unique()
909915
# -- Save function space --
910916
self._save_function_space(V)
911917
# -- Save function --
@@ -1224,6 +1230,8 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
12241230

12251231
@PETSc.Log.EventDecorator("LoadFunctionSpace")
12261232
def _load_function_space(self, mesh, name):
1233+
# TODO: Add general MixedMesh support.
1234+
mesh = mesh.unique()
12271235
mesh.init()
12281236
mesh_key = self._generate_mesh_key_from_names(mesh.name,
12291237
mesh.topology._distribution_name,
@@ -1301,6 +1309,8 @@ def load_function(self, mesh, name, idx=None):
13011309
be loaded with idx only when it was saved with idx.
13021310
:returns: the loaded :class:`~.Function`.
13031311
"""
1312+
# TODO: Add general MixedMesh support.
1313+
mesh = mesh.unique()
13041314
tmesh = mesh.topology
13051315
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
13061316
V_name = self._get_mixed_function_name_mixed_function_space_name_map(mesh.name)[name]

firedrake/dmhooks.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
import firedrake
4545
from firedrake.petsc import PETSc
46-
46+
from firedrake.mesh import MixedMeshGeometry
4747

4848
@PETSc.Log.EventDecorator()
4949
def get_function_space(dm):
@@ -53,8 +53,11 @@ def get_function_space(dm):
5353
:raises RuntimeError: if no function space was found.
5454
"""
5555
info = dm.getAttr("__fs_info__")
56-
meshref, element, indices, (name, names) = info
57-
mesh = meshref()
56+
meshref_tuple, element, indices, (name, names) = info
57+
if len(meshref_tuple) == 1:
58+
mesh = meshref_tuple[0]()
59+
else:
60+
mesh = MixedMeshGeometry(*(meshref() for meshref in meshref_tuple))
5861
if mesh is None:
5962
raise RuntimeError("Somehow your mesh was collected, this should never happen")
6063
V = firedrake.FunctionSpace(mesh, element, name=name)
@@ -78,8 +81,6 @@ def set_function_space(dm, V):
7881
This stores the information necessary to make a function space given a DM.
7982
8083
"""
81-
mesh = V.mesh()
82-
8384
indices = []
8485
names = []
8586
while V.parent is not None:
@@ -90,11 +91,12 @@ def set_function_space(dm, V):
9091
assert V.index is None
9192
indices.append(V.component)
9293
V = V.parent
94+
mesh = V.mesh()
9395
if len(V) > 1:
9496
names = tuple(V_.name for V_ in V)
9597
element = V.ufl_element()
9698

97-
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
99+
info = (tuple(weakref.ref(m) for m in mesh), element, tuple(reversed(indices)), (V.name, names))
98100
dm.setAttr("__fs_info__", info)
99101

100102

@@ -412,7 +414,9 @@ def coarsen(dm, comm):
412414
"""
413415
from firedrake.mg.utils import get_level
414416
V = get_function_space(dm)
415-
hierarchy, level = get_level(V.mesh())
417+
# TODO: Think harder.
418+
m, = set(m_ for m_ in V.mesh())
419+
hierarchy, level = get_level(m)
416420
if level < 1:
417421
raise RuntimeError("Cannot coarsen coarsest DM")
418422
coarsen = get_ctx_coarsener(dm)

0 commit comments

Comments
 (0)