Skip to content

Commit 533f482

Browse files
committed
submesh: enable cell submesh
1 parent ce7b8ad commit 533f482

File tree

11 files changed

+1884
-138
lines changed

11 files changed

+1884
-138
lines changed

firedrake/__future__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ufl.domain import as_domain, extract_unique_domain
22
from ufl.algorithms import extract_arguments
3-
from firedrake.mesh import VertexOnlyMeshTopology
3+
from firedrake.mesh import MeshTopology, VertexOnlyMeshTopology
44
from firedrake.interpolation import (interpolate as interpolate_old,
55
Interpolator as InterpolatorOld,
66
SameMeshInterpolator as SameMeshInterpolatorOld,
@@ -16,13 +16,13 @@ class Interpolator(InterpolatorOld):
1616
def __new__(cls, expr, V, **kwargs):
1717
target_mesh = as_domain(V)
1818
source_mesh = extract_unique_domain(expr) or target_mesh
19-
if target_mesh is not source_mesh:
19+
if target_mesh is source_mesh or all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]:
20+
return object.__new__(SameMeshInterpolator)
21+
else:
2022
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
2123
return object.__new__(SameMeshInterpolator)
2224
else:
2325
return object.__new__(CrossMeshInterpolator)
24-
else:
25-
return object.__new__(SameMeshInterpolator)
2626

2727
interpolate = InterpolatorOld._interpolate_future
2828

firedrake/assemble.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
1919
tsfc_interface, utils)
2020
from firedrake.adjoint_utils import annotate_assemble
21-
from firedrake.ufl_expr import extract_unique_domain
21+
from firedrake.ufl_expr import extract_domains
2222
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
2323
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
2424
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
@@ -1037,15 +1037,16 @@ def local_kernels(self):
10371037
each possible combination.
10381038
10391039
"""
1040-
#try:
1041-
# topology, = set(d.topology for d in self._form.ufl_domains())
1042-
#except ValueError:
1043-
# raise NotImplementedError("All integration domains must share a mesh topology")
1040+
try:
1041+
topology, = set(d.topology.submesh_ancesters[-1] for d in self._form.ufl_domains())
1042+
except ValueError:
1043+
raise NotImplementedError("All integration domains must share a mesh topology")
10441044

1045-
#for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046-
# domain = extract_unique_domain(o)
1047-
# if domain is not None and domain.topology != topology:
1048-
# raise NotImplementedError("Assembly with multiple meshes is not supported")
1045+
for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046+
domains = extract_domains(o)
1047+
for domain in domains:
1048+
if domain is not None and domain.topology.submesh_ancesters[-1] != topology:
1049+
raise NotImplementedError("Assembly with multiple meshes is not supported")
10491050

10501051
if isinstance(self._form, ufl.Form):
10511052
kernels = tsfc_interface.compile_form(
@@ -1337,20 +1338,23 @@ def _make_maps_and_regions(self):
13371338
test, trial = self._form.arguments()
13381339
if self._allocation_integral_types is not None:
13391340
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, self._allocation_integral_types)
1340-
elif any(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels):
1341+
elif any(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels):
13411342
# Handle special cases: slate or split=False
1342-
assert all(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels)
1343+
assert all(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels)
13431344
allocation_integral_types = set(local_kernel.kinfo.integral_type
1344-
for local_kernel in self._all_local_kernels)
1345+
for local_kernel, _ in self._all_local_kernels)
13451346
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, allocation_integral_types)
13461347
else:
13471348
maps_and_regions = defaultdict(lambda: defaultdict(set))
1348-
for local_kernel in self._all_local_kernels:
1349+
all_meshes = extract_domains(self._form)
1350+
for local_kernel, subdomain_id in self._all_local_kernels:
13491351
i, j = local_kernel.indices
1352+
mesh = all_meshes[local_kernel.kinfo.domain_number] # integration domain
13501353
# Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
13511354
integral_type = local_kernel.kinfo.integral_type
1352-
rmap_ = test.function_space().topological[i].entity_node_map(integral_type)
1353-
cmap_ = trial.function_space().topological[j].entity_node_map(integral_type)
1355+
all_subdomain_ids = self.all_integer_subdomain_ids[local_kernel.kinfo.domain_number]
1356+
rmap_ = test.function_space().topological[i].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
1357+
cmap_ = trial.function_space().topological[j].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
13541358
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
13551359
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
13561360
return {block_indices: [map_pair + (tuple(region_set), ) for map_pair, region_set in map_pair_to_region_set.items()]
@@ -1366,8 +1370,14 @@ def _make_maps_and_regions_default(test, trial, allocation_integral_types):
13661370
# Use outer product of component maps.
13671371
for integral_type in allocation_integral_types:
13681372
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
1369-
for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(integral_type)):
1370-
for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(integral_type)):
1373+
#for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1374+
# for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1375+
# maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
1376+
for i, Vrow in enumerate(test.function_space()):
1377+
for j, Vcol in enumerate(trial.function_space()):
1378+
mesh = Vrow.mesh()
1379+
rmap_ = Vrow.topological.entity_node_map(mesh.topology, integral_type, None, None)
1380+
cmap_ = Vcol.topological.entity_node_map(mesh.topology, integral_type, None, None)
13711381
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
13721382
return {block_indices: [map_pair + (tuple(region_set), ) for map_pair, region_set in map_pair_to_region_set.items()]
13731383
for block_indices, map_pair_to_region_set in maps_and_regions.items()}
@@ -1389,7 +1399,7 @@ def _all_local_kernels(self):
13891399
When constructing sparsity, we use all parloop_builders
13901400
that are to be used in the actual assembly.
13911401
"""
1392-
all_local_kernels = tuple(local_kernel for local_kernel, _ in self.local_kernels)
1402+
all_local_kernels = self.local_kernels
13931403
for bc in self._bcs:
13941404
if isinstance(bc, EquationBCSplit):
13951405
_assembler = type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False)
@@ -1559,7 +1569,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
15591569
self._form = form
15601570
self._indices, self._kinfo = local_knl
15611571
self._subdomain_id = subdomain_id
1562-
self._all_integer_subdomain_ids = all_integer_subdomain_ids.get(self._kinfo.integral_type, None)
1572+
self._all_integer_subdomain_ids = all_integer_subdomain_ids
15631573
self._diagonal = diagonal
15641574
self._unroll = unroll
15651575

@@ -1626,7 +1636,7 @@ def _needs_subset(self):
16261636
if self._subdomain_id == "everywhere":
16271637
return False
16281638
elif self._subdomain_id == "otherwise":
1629-
return self._all_integer_subdomain_ids is not None
1639+
return self._all_integer_subdomain_ids.get(self._kinfo.integral_type, None) is not None
16301640
else:
16311641
return True
16321642

@@ -1646,7 +1656,7 @@ def _get_dim(self, finat_element):
16461656

16471657
def _make_dat_global_kernel_arg(self, V, index=None):
16481658
finat_element = create_element(V.ufl_element())
1649-
map_arg = V.topological.entity_node_map(self._integral_type)._global_kernel_arg
1659+
map_arg = V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg
16501660
if isinstance(finat_element, finat.EnrichedElement) and finat_element.is_mixed:
16511661
assert index is None
16521662
subargs = tuple(self._make_dat_global_kernel_arg(Vsub, index=index)
@@ -1664,7 +1674,7 @@ def _make_mat_global_kernel_arg(self, Vrow, Vcol):
16641674
shape = len(relem.elements), len(celem.elements)
16651675
return op2.MixedMatKernelArg(subargs, shape)
16661676
else:
1667-
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._integral_type)._global_kernel_arg for V in [Vrow, Vcol])
1677+
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg for V in [Vrow, Vcol])
16681678
# PyOP2 matrix objects have scalar dims so we flatten them here
16691679
rdim = numpy.prod(self._get_dim(relem), dtype=int)
16701680
cdim = numpy.prod(self._get_dim(celem), dtype=int)
@@ -1765,14 +1775,24 @@ def _as_global_kernel_arg_constant(_, self):
17651775

17661776
@_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg)
17671777
def _as_global_kernel_arg_exterior_facet(_, self):
1768-
_ = next(self._active_exterior_facets)
1769-
return op2.DatKernelArg((1,))
1778+
mesh, _ = next(self._active_exterior_facets)
1779+
if mesh is self._mesh:
1780+
return op2.DatKernelArg((1,))
1781+
else:
1782+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
1783+
assert integral_type == "exterior_facet"
1784+
return op2.DatKernelArg((1,), m._global_kernel_arg)
17701785

17711786

17721787
@_as_global_kernel_arg.register(kernel_args.InteriorFacetKernelArg)
17731788
def _as_global_kernel_arg_interior_facet(_, self):
1774-
_ = next(self._active_interior_facets)
1775-
return op2.DatKernelArg((2,))
1789+
mesh, _ = next(self._active_interior_facets)
1790+
if mesh is self._mesh:
1791+
return op2.DatKernelArg((2,))
1792+
else:
1793+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
1794+
assert integral_type == "interior_facet"
1795+
return op2.DatKernelArg((2,), m._global_kernel_arg)
17761796

17771797

17781798
@_as_global_kernel_arg.register(CellFacetKernelArg)
@@ -1978,7 +1998,7 @@ def _iterset(self):
19781998
def _get_map(self, V):
19791999
"""Return the appropriate PyOP2 map for a given function space."""
19802000
assert isinstance(V, (WithGeometry, FiredrakeDualSpace, FunctionSpace))
1981-
return V.entity_node_map(self._integral_type)
2001+
return V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
19822002

19832003
def _as_parloop_arg(self, tsfc_arg):
19842004
"""Return a :class:`op2.ParloopArg` corresponding to the provided
@@ -2066,7 +2086,7 @@ def _as_parloop_arg_exterior_facet(_, self):
20662086
m = None
20672087
else:
20682088
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2069-
assert integral_type == "exterior_facets"
2089+
assert integral_type == "exterior_facet"
20702090
return op2.DatParloopArg(local_facet_dat, m)
20712091

20722092

@@ -2077,7 +2097,7 @@ def _as_parloop_arg_interior_facet(_, self):
20772097
m = None
20782098
else:
20792099
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2080-
assert integral_type == "interior_facets"
2100+
assert integral_type == "interior_facet"
20812101
return op2.DatParloopArg(local_facet_dat, m)
20822102

20832103

0 commit comments

Comments
 (0)