Skip to content

Commit b061db4

Browse files
committed
submesh: enable cell submesh
1 parent 6c564ac commit b061db4

File tree

11 files changed

+1791
-138
lines changed

11 files changed

+1791
-138
lines changed

firedrake/__future__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ufl.domain import as_domain, extract_unique_domain
22
from ufl.algorithms import extract_arguments
3-
from firedrake.mesh import VertexOnlyMeshTopology
3+
from firedrake.mesh import MeshTopology, VertexOnlyMeshTopology
44
from firedrake.interpolation import (interpolate as interpolate_old,
55
Interpolator as InterpolatorOld,
66
SameMeshInterpolator as SameMeshInterpolatorOld,
@@ -16,13 +16,13 @@ class Interpolator(InterpolatorOld):
1616
def __new__(cls, expr, V, **kwargs):
1717
target_mesh = as_domain(V)
1818
source_mesh = extract_unique_domain(expr) or target_mesh
19-
if target_mesh is not source_mesh:
19+
if target_mesh is source_mesh or all(isinstance(m.topology, MeshTopology) for m in [target_mesh, source_mesh]) and target_mesh.submesh_ancesters[-1] is source_mesh.submesh_ancesters[-1]:
20+
return object.__new__(SameMeshInterpolator)
21+
else:
2022
if isinstance(target_mesh.topology, VertexOnlyMeshTopology):
2123
return object.__new__(SameMeshInterpolator)
2224
else:
2325
return object.__new__(CrossMeshInterpolator)
24-
else:
25-
return object.__new__(SameMeshInterpolator)
2626

2727
interpolate = InterpolatorOld._interpolate_future
2828

firedrake/assemble.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
1919
tsfc_interface, utils)
2020
from firedrake.adjoint_utils import annotate_assemble
21-
from firedrake.ufl_expr import extract_unique_domain
21+
from firedrake.ufl_expr import extract_domains
2222
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
2323
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
2424
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
@@ -1037,15 +1037,16 @@ def local_kernels(self):
10371037
each possible combination.
10381038
10391039
"""
1040-
#try:
1041-
# topology, = set(d.topology for d in self._form.ufl_domains())
1042-
#except ValueError:
1043-
# raise NotImplementedError("All integration domains must share a mesh topology")
1040+
try:
1041+
topology, = set(d.topology.submesh_ancesters[-1] for d in self._form.ufl_domains())
1042+
except ValueError:
1043+
raise NotImplementedError("All integration domains must share a mesh topology")
10441044

1045-
#for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046-
# domain = extract_unique_domain(o)
1047-
# if domain is not None and domain.topology != topology:
1048-
# raise NotImplementedError("Assembly with multiple meshes is not supported")
1045+
for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046+
domains = extract_domains(o)
1047+
for domain in domains:
1048+
if domain is not None and domain.topology.submesh_ancesters[-1] != topology:
1049+
raise NotImplementedError("Assembly with multiple meshes is not supported")
10491050

10501051
if isinstance(self._form, ufl.Form):
10511052
kernels = tsfc_interface.compile_form(
@@ -1337,21 +1338,24 @@ def _make_maps_and_regions(self):
13371338
test, trial = self._form.arguments()
13381339
if self._allocation_integral_types is not None:
13391340
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, self._allocation_integral_types)
1340-
elif any(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels):
1341+
elif any(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels):
13411342
# Handle special cases: slate or split=False
1342-
assert all(local_kernel.indices == (None, None) for local_kernel in self._all_local_kernels)
1343+
assert all(local_kernel.indices == (None, None) for local_kernel, _ in self._all_local_kernels)
13431344
allocation_integral_types = set(local_kernel.kinfo.integral_type
1344-
for local_kernel in self._all_local_kernels)
1345+
for local_kernel, _ in self._all_local_kernels)
13451346
return ExplicitMatrixAssembler._make_maps_and_regions_default(test, trial, allocation_integral_types)
13461347
else:
13471348
maps_and_regions = defaultdict(lambda: defaultdict(set))
1348-
for local_kernel in self._all_local_kernels:
1349+
all_meshes = extract_domains(self._form)
1350+
for local_kernel, subdomain_id in self._all_local_kernels:
13491351
i, j = local_kernel.indices
1352+
mesh = all_meshes[local_kernel.kinfo.domain_number] # integration domain
13501353
# Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
13511354
integral_type = local_kernel.kinfo.integral_type
1352-
rmap_ = test.function_space().topological[i].entity_node_map(integral_type)
1355+
all_subdomain_ids = self.all_integer_subdomain_ids[local_kernel.kinfo.domain_number]
1356+
rmap_ = test.function_space().topological[i].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
13531357
# rmap_ = rmap_.split[i] if rmap_ is not None else None
1354-
cmap_ = trial.function_space().topological[j].entity_node_map(integral_type)
1358+
cmap_ = trial.function_space().topological[j].entity_node_map(mesh.topology, integral_type, subdomain_id, all_subdomain_ids)
13551359
# cmap_ = cmap_.split[j] if cmap_ is not None else None
13561360
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
13571361
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
@@ -1368,8 +1372,14 @@ def _make_maps_and_regions_default(test, trial, allocation_integral_types):
13681372
# Use outer product of component maps.
13691373
for integral_type in allocation_integral_types:
13701374
region = ExplicitMatrixAssembler._integral_type_region_map[integral_type]
1371-
for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(integral_type)):
1372-
for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(integral_type)):
1375+
#for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1376+
# for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1377+
# maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
1378+
for i, Vrow in enumerate(test.function_space()):
1379+
for j, Vcol in enumerate(trial.function_space()):
1380+
mesh = Vrow.mesh()
1381+
rmap_ = Vrow.topological.entity_node_map(mesh.topology, integral_type, None, None)
1382+
cmap_ = Vcol.topological.entity_node_map(mesh.topology, integral_type, None, None)
13731383
maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
13741384
return {block_indices: [map_pair + (tuple(region_set), ) for map_pair, region_set in map_pair_to_region_set.items()]
13751385
for block_indices, map_pair_to_region_set in maps_and_regions.items()}
@@ -1391,7 +1401,7 @@ def _all_local_kernels(self):
13911401
When constructing sparsity, we use all parloop_builders
13921402
that are to be used in the actual assembly.
13931403
"""
1394-
all_local_kernels = tuple(local_kernel for local_kernel, _ in self.local_kernels)
1404+
all_local_kernels = self.local_kernels
13951405
for bc in self._bcs:
13961406
if isinstance(bc, EquationBCSplit):
13971407
_assembler = type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False)
@@ -1561,7 +1571,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
15611571
self._form = form
15621572
self._indices, self._kinfo = local_knl
15631573
self._subdomain_id = subdomain_id
1564-
self._all_integer_subdomain_ids = all_integer_subdomain_ids.get(self._kinfo.integral_type, None)
1574+
self._all_integer_subdomain_ids = all_integer_subdomain_ids
15651575
self._diagonal = diagonal
15661576
self._unroll = unroll
15671577

@@ -1628,7 +1638,7 @@ def _needs_subset(self):
16281638
if self._subdomain_id == "everywhere":
16291639
return False
16301640
elif self._subdomain_id == "otherwise":
1631-
return self._all_integer_subdomain_ids is not None
1641+
return self._all_integer_subdomain_ids.get(self._kinfo.integral_type, None) is not None
16321642
else:
16331643
return True
16341644

@@ -1648,7 +1658,7 @@ def _get_dim(self, finat_element):
16481658

16491659
def _make_dat_global_kernel_arg(self, V, index=None):
16501660
finat_element = create_element(V.ufl_element())
1651-
map_arg = V.topological.entity_node_map(self._integral_type)._global_kernel_arg
1661+
map_arg = V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg
16521662
if isinstance(finat_element, finat.EnrichedElement) and finat_element.is_mixed:
16531663
assert index is None
16541664
subargs = tuple(self._make_dat_global_kernel_arg(Vsub, index=index)
@@ -1666,7 +1676,7 @@ def _make_mat_global_kernel_arg(self, Vrow, Vcol):
16661676
shape = len(relem.elements), len(celem.elements)
16671677
return op2.MixedMatKernelArg(subargs, shape)
16681678
else:
1669-
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._integral_type)._global_kernel_arg for V in [Vrow, Vcol])
1679+
rmap_arg, cmap_arg = (V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)._global_kernel_arg for V in [Vrow, Vcol])
16701680
# PyOP2 matrix objects have scalar dims so we flatten them here
16711681
rdim = numpy.prod(self._get_dim(relem), dtype=int)
16721682
cdim = numpy.prod(self._get_dim(celem), dtype=int)
@@ -1767,14 +1777,24 @@ def _as_global_kernel_arg_constant(_, self):
17671777

17681778
@_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg)
17691779
def _as_global_kernel_arg_exterior_facet(_, self):
1770-
_ = next(self._active_exterior_facets)
1771-
return op2.DatKernelArg((1,))
1780+
mesh, _ = next(self._active_exterior_facets)
1781+
if mesh is self._mesh:
1782+
return op2.DatKernelArg((1,))
1783+
else:
1784+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
1785+
assert integral_type == "exterior_facet"
1786+
return op2.DatKernelArg((1,), m._global_kernel_arg)
17721787

17731788

17741789
@_as_global_kernel_arg.register(kernel_args.InteriorFacetKernelArg)
17751790
def _as_global_kernel_arg_interior_facet(_, self):
1776-
_ = next(self._active_interior_facets)
1777-
return op2.DatKernelArg((2,))
1791+
mesh, _ = next(self._active_interior_facets)
1792+
if mesh is self._mesh:
1793+
return op2.DatKernelArg((2,))
1794+
else:
1795+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
1796+
assert integral_type == "interior_facet"
1797+
return op2.DatKernelArg((2,), m._global_kernel_arg)
17781798

17791799

17801800
@_as_global_kernel_arg.register(CellFacetKernelArg)
@@ -1980,7 +2000,7 @@ def _iterset(self):
19802000
def _get_map(self, V):
19812001
"""Return the appropriate PyOP2 map for a given function space."""
19822002
assert isinstance(V, (WithGeometry, FiredrakeDualSpace, FunctionSpace))
1983-
return V.entity_node_map(self._integral_type)
2003+
return V.topological.entity_node_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
19842004

19852005
def _as_parloop_arg(self, tsfc_arg):
19862006
"""Return a :class:`op2.ParloopArg` corresponding to the provided
@@ -2068,7 +2088,7 @@ def _as_parloop_arg_exterior_facet(_, self):
20682088
m = None
20692089
else:
20702090
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2071-
assert integral_type == "exterior_facets"
2091+
assert integral_type == "exterior_facet"
20722092
return op2.DatParloopArg(local_facet_dat, m)
20732093

20742094

@@ -2079,7 +2099,7 @@ def _as_parloop_arg_interior_facet(_, self):
20792099
m = None
20802100
else:
20812101
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2082-
assert integral_type == "interior_facets"
2102+
assert integral_type == "interior_facet"
20832103
return op2.DatParloopArg(local_facet_dat, m)
20842104

20852105

0 commit comments

Comments
 (0)