1818from firedrake import (extrusion_utils as eutils , matrix , parameters , solving ,
1919 tsfc_interface , utils )
2020from firedrake .adjoint_utils import annotate_assemble
21- from firedrake .ufl_expr import extract_unique_domain
21+ from firedrake .ufl_expr import extract_domains
2222from firedrake .bcs import DirichletBC , EquationBC , EquationBCSplit
2323from firedrake .functionspaceimpl import WithGeometry , FunctionSpace , FiredrakeDualSpace
2424from firedrake .functionspacedata import entity_dofs_key , entity_permutations_key
@@ -1037,15 +1037,16 @@ def local_kernels(self):
10371037 each possible combination.
10381038
10391039 """
1040- # try:
1041- # topology, = set(d.topology for d in self._form.ufl_domains())
1042- # except ValueError:
1043- # raise NotImplementedError("All integration domains must share a mesh topology")
1040+ try :
1041+ topology , = set (d .topology . submesh_ancesters [ - 1 ] for d in self ._form .ufl_domains ())
1042+ except ValueError :
1043+ raise NotImplementedError ("All integration domains must share a mesh topology" )
10441044
1045- #for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046- # domain = extract_unique_domain(o)
1047- # if domain is not None and domain.topology != topology:
1048- # raise NotImplementedError("Assembly with multiple meshes is not supported")
1045+ for o in itertools .chain (self ._form .arguments (), self ._form .coefficients ()):
1046+ domains = extract_domains (o )
1047+ for domain in domains :
1048+ if domain is not None and domain .topology .submesh_ancesters [- 1 ] != topology :
1049+ raise NotImplementedError ("Assembly with multiple meshes is not supported" )
10491050
10501051 if isinstance (self ._form , ufl .Form ):
10511052 kernels = tsfc_interface .compile_form (
@@ -1337,21 +1338,24 @@ def _make_maps_and_regions(self):
13371338 test , trial = self ._form .arguments ()
13381339 if self ._allocation_integral_types is not None :
13391340 return ExplicitMatrixAssembler ._make_maps_and_regions_default (test , trial , self ._allocation_integral_types )
1340- elif any (local_kernel .indices == (None , None ) for local_kernel in self ._all_local_kernels ):
1341+ elif any (local_kernel .indices == (None , None ) for local_kernel , _ in self ._all_local_kernels ):
13411342 # Handle special cases: slate or split=False
1342- assert all (local_kernel .indices == (None , None ) for local_kernel in self ._all_local_kernels )
1343+ assert all (local_kernel .indices == (None , None ) for local_kernel , _ in self ._all_local_kernels )
13431344 allocation_integral_types = set (local_kernel .kinfo .integral_type
1344- for local_kernel in self ._all_local_kernels )
1345+ for local_kernel , _ in self ._all_local_kernels )
13451346 return ExplicitMatrixAssembler ._make_maps_and_regions_default (test , trial , allocation_integral_types )
13461347 else :
13471348 maps_and_regions = defaultdict (lambda : defaultdict (set ))
1348- for local_kernel in self ._all_local_kernels :
1349+ all_meshes = extract_domains (self ._form )
1350+ for local_kernel , subdomain_id in self ._all_local_kernels :
13491351 i , j = local_kernel .indices
1352+ mesh = all_meshes [local_kernel .kinfo .domain_number ] # integration domain
13501353 # Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
13511354 integral_type = local_kernel .kinfo .integral_type
1352- rmap_ = test .function_space ().topological [i ].entity_node_map (integral_type )
1355+ all_subdomain_ids = self .all_integer_subdomain_ids [local_kernel .kinfo .domain_number ]
1356+ rmap_ = test .function_space ().topological [i ].entity_node_map (mesh .topology , integral_type , subdomain_id , all_subdomain_ids )
13531357 # rmap_ = rmap_.split[i] if rmap_ is not None else None
1354- cmap_ = trial .function_space ().topological [j ].entity_node_map (integral_type )
1358+ cmap_ = trial .function_space ().topological [j ].entity_node_map (mesh . topology , integral_type , subdomain_id , all_subdomain_ids )
13551359 # cmap_ = cmap_.split[j] if cmap_ is not None else None
13561360 region = ExplicitMatrixAssembler ._integral_type_region_map [integral_type ]
13571361 maps_and_regions [(i , j )][(rmap_ , cmap_ )].add (region )
@@ -1368,8 +1372,14 @@ def _make_maps_and_regions_default(test, trial, allocation_integral_types):
13681372 # Use outer product of component maps.
13691373 for integral_type in allocation_integral_types :
13701374 region = ExplicitMatrixAssembler ._integral_type_region_map [integral_type ]
1371- for i , rmap_ in enumerate (test .function_space ().topological .entity_node_map (integral_type )):
1372- for j , cmap_ in enumerate (trial .function_space ().topological .entity_node_map (integral_type )):
1375+ #for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1376+ # for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1377+ # maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
1378+ for i , Vrow in enumerate (test .function_space ()):
1379+ for j , Vcol in enumerate (trial .function_space ()):
1380+ mesh = Vrow .mesh ()
1381+ rmap_ = Vrow .topological .entity_node_map (mesh .topology , integral_type , None , None )
1382+ cmap_ = Vcol .topological .entity_node_map (mesh .topology , integral_type , None , None )
13731383 maps_and_regions [(i , j )][(rmap_ , cmap_ )].add (region )
13741384 return {block_indices : [map_pair + (tuple (region_set ), ) for map_pair , region_set in map_pair_to_region_set .items ()]
13751385 for block_indices , map_pair_to_region_set in maps_and_regions .items ()}
@@ -1391,7 +1401,7 @@ def _all_local_kernels(self):
13911401 When constructing sparsity, we use all parloop_builders
13921402 that are to be used in the actual assembly.
13931403 """
1394- all_local_kernels = tuple ( local_kernel for local_kernel , _ in self .local_kernels )
1404+ all_local_kernels = self .local_kernels
13951405 for bc in self ._bcs :
13961406 if isinstance (bc , EquationBCSplit ):
13971407 _assembler = type (self )(bc .f , bcs = bc .bcs , form_compiler_parameters = self ._form_compiler_params , needs_zeroing = False )
@@ -1561,7 +1571,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
15611571 self ._form = form
15621572 self ._indices , self ._kinfo = local_knl
15631573 self ._subdomain_id = subdomain_id
1564- self ._all_integer_subdomain_ids = all_integer_subdomain_ids . get ( self . _kinfo . integral_type , None )
1574+ self ._all_integer_subdomain_ids = all_integer_subdomain_ids
15651575 self ._diagonal = diagonal
15661576 self ._unroll = unroll
15671577
@@ -1628,7 +1638,7 @@ def _needs_subset(self):
16281638 if self ._subdomain_id == "everywhere" :
16291639 return False
16301640 elif self ._subdomain_id == "otherwise" :
1631- return self ._all_integer_subdomain_ids is not None
1641+ return self ._all_integer_subdomain_ids . get ( self . _kinfo . integral_type , None ) is not None
16321642 else :
16331643 return True
16341644
@@ -1648,7 +1658,7 @@ def _get_dim(self, finat_element):
16481658
16491659 def _make_dat_global_kernel_arg (self , V , index = None ):
16501660 finat_element = create_element (V .ufl_element ())
1651- map_arg = V .topological .entity_node_map (self ._integral_type )._global_kernel_arg
1661+ map_arg = V .topological .entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )._global_kernel_arg
16521662 if isinstance (finat_element , finat .EnrichedElement ) and finat_element .is_mixed :
16531663 assert index is None
16541664 subargs = tuple (self ._make_dat_global_kernel_arg (Vsub , index = index )
@@ -1666,7 +1676,7 @@ def _make_mat_global_kernel_arg(self, Vrow, Vcol):
16661676 shape = len (relem .elements ), len (celem .elements )
16671677 return op2 .MixedMatKernelArg (subargs , shape )
16681678 else :
1669- rmap_arg , cmap_arg = (V .topological .entity_node_map (self ._integral_type )._global_kernel_arg for V in [Vrow , Vcol ])
1679+ rmap_arg , cmap_arg = (V .topological .entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )._global_kernel_arg for V in [Vrow , Vcol ])
16701680 # PyOP2 matrix objects have scalar dims so we flatten them here
16711681 rdim = numpy .prod (self ._get_dim (relem ), dtype = int )
16721682 cdim = numpy .prod (self ._get_dim (celem ), dtype = int )
@@ -1767,14 +1777,24 @@ def _as_global_kernel_arg_constant(_, self):
17671777
17681778@_as_global_kernel_arg .register (kernel_args .ExteriorFacetKernelArg )
17691779def _as_global_kernel_arg_exterior_facet (_ , self ):
1770- _ = next (self ._active_exterior_facets )
1771- return op2 .DatKernelArg ((1 ,))
1780+ mesh , _ = next (self ._active_exterior_facets )
1781+ if mesh is self ._mesh :
1782+ return op2 .DatKernelArg ((1 ,))
1783+ else :
1784+ m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
1785+ assert integral_type == "exterior_facet"
1786+ return op2 .DatKernelArg ((1 ,), m ._global_kernel_arg )
17721787
17731788
17741789@_as_global_kernel_arg .register (kernel_args .InteriorFacetKernelArg )
17751790def _as_global_kernel_arg_interior_facet (_ , self ):
1776- _ = next (self ._active_interior_facets )
1777- return op2 .DatKernelArg ((2 ,))
1791+ mesh , _ = next (self ._active_interior_facets )
1792+ if mesh is self ._mesh :
1793+ return op2 .DatKernelArg ((2 ,))
1794+ else :
1795+ m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
1796+ assert integral_type == "interior_facet"
1797+ return op2 .DatKernelArg ((2 ,), m ._global_kernel_arg )
17781798
17791799
17801800@_as_global_kernel_arg .register (CellFacetKernelArg )
@@ -1980,7 +2000,7 @@ def _iterset(self):
19802000 def _get_map (self , V ):
19812001 """Return the appropriate PyOP2 map for a given function space."""
19822002 assert isinstance (V , (WithGeometry , FiredrakeDualSpace , FunctionSpace ))
1983- return V .entity_node_map (self ._integral_type )
2003+ return V .topological . entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )
19842004
19852005 def _as_parloop_arg (self , tsfc_arg ):
19862006 """Return a :class:`op2.ParloopArg` corresponding to the provided
@@ -2068,7 +2088,7 @@ def _as_parloop_arg_exterior_facet(_, self):
20682088 m = None
20692089 else :
20702090 m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
2071- assert integral_type == "exterior_facets "
2091+ assert integral_type == "exterior_facet "
20722092 return op2 .DatParloopArg (local_facet_dat , m )
20732093
20742094
@@ -2079,7 +2099,7 @@ def _as_parloop_arg_interior_facet(_, self):
20792099 m = None
20802100 else :
20812101 m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
2082- assert integral_type == "interior_facets "
2102+ assert integral_type == "interior_facet "
20832103 return op2 .DatParloopArg (local_facet_dat , m )
20842104
20852105
0 commit comments