1818from firedrake import (extrusion_utils as eutils , matrix , parameters , solving ,
1919 tsfc_interface , utils )
2020from firedrake .adjoint_utils import annotate_assemble
21- from firedrake .ufl_expr import extract_unique_domain
21+ from firedrake .ufl_expr import extract_domains
2222from firedrake .bcs import DirichletBC , EquationBC , EquationBCSplit
2323from firedrake .functionspaceimpl import WithGeometry , FunctionSpace , FiredrakeDualSpace
2424from firedrake .functionspacedata import entity_dofs_key , entity_permutations_key
@@ -1037,15 +1037,16 @@ def local_kernels(self):
10371037 each possible combination.
10381038
10391039 """
1040- # try:
1041- # topology, = set(d.topology for d in self._form.ufl_domains())
1042- # except ValueError:
1043- # raise NotImplementedError("All integration domains must share a mesh topology")
1040+ try :
1041+ topology , = set (d .topology . submesh_ancesters [ - 1 ] for d in self ._form .ufl_domains ())
1042+ except ValueError :
1043+ raise NotImplementedError ("All integration domains must share a mesh topology" )
10441044
1045- #for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1046- # domain = extract_unique_domain(o)
1047- # if domain is not None and domain.topology != topology:
1048- # raise NotImplementedError("Assembly with multiple meshes is not supported")
1045+ for o in itertools .chain (self ._form .arguments (), self ._form .coefficients ()):
1046+ domains = extract_domains (o )
1047+ for domain in domains :
1048+ if domain is not None and domain .topology .submesh_ancesters [- 1 ] != topology :
1049+ raise NotImplementedError ("Assembly with multiple meshes is not supported" )
10491050
10501051 if isinstance (self ._form , ufl .Form ):
10511052 kernels = tsfc_interface .compile_form (
@@ -1337,20 +1338,23 @@ def _make_maps_and_regions(self):
13371338 test , trial = self ._form .arguments ()
13381339 if self ._allocation_integral_types is not None :
13391340 return ExplicitMatrixAssembler ._make_maps_and_regions_default (test , trial , self ._allocation_integral_types )
1340- elif any (local_kernel .indices == (None , None ) for local_kernel in self ._all_local_kernels ):
1341+ elif any (local_kernel .indices == (None , None ) for local_kernel , _ in self ._all_local_kernels ):
13411342 # Handle special cases: slate or split=False
1342- assert all (local_kernel .indices == (None , None ) for local_kernel in self ._all_local_kernels )
1343+ assert all (local_kernel .indices == (None , None ) for local_kernel , _ in self ._all_local_kernels )
13431344 allocation_integral_types = set (local_kernel .kinfo .integral_type
1344- for local_kernel in self ._all_local_kernels )
1345+ for local_kernel , _ in self ._all_local_kernels )
13451346 return ExplicitMatrixAssembler ._make_maps_and_regions_default (test , trial , allocation_integral_types )
13461347 else :
13471348 maps_and_regions = defaultdict (lambda : defaultdict (set ))
1348- for local_kernel in self ._all_local_kernels :
1349+ all_meshes = extract_domains (self ._form )
1350+ for local_kernel , subdomain_id in self ._all_local_kernels :
13491351 i , j = local_kernel .indices
1352+ mesh = all_meshes [local_kernel .kinfo .domain_number ] # integration domain
13501353 # Make Sparsity independent of _iterset, which can be a Subset, for better reusability.
13511354 integral_type = local_kernel .kinfo .integral_type
1352- rmap_ = test .function_space ().topological [i ].entity_node_map (integral_type )
1353- cmap_ = trial .function_space ().topological [j ].entity_node_map (integral_type )
1355+ all_subdomain_ids = self .all_integer_subdomain_ids [local_kernel .kinfo .domain_number ]
1356+ rmap_ = test .function_space ().topological [i ].entity_node_map (mesh .topology , integral_type , subdomain_id , all_subdomain_ids )
1357+ cmap_ = trial .function_space ().topological [j ].entity_node_map (mesh .topology , integral_type , subdomain_id , all_subdomain_ids )
13541358 region = ExplicitMatrixAssembler ._integral_type_region_map [integral_type ]
13551359 maps_and_regions [(i , j )][(rmap_ , cmap_ )].add (region )
13561360 return {block_indices : [map_pair + (tuple (region_set ), ) for map_pair , region_set in map_pair_to_region_set .items ()]
@@ -1366,8 +1370,14 @@ def _make_maps_and_regions_default(test, trial, allocation_integral_types):
13661370 # Use outer product of component maps.
13671371 for integral_type in allocation_integral_types :
13681372 region = ExplicitMatrixAssembler ._integral_type_region_map [integral_type ]
1369- for i , rmap_ in enumerate (test .function_space ().topological .entity_node_map (integral_type )):
1370- for j , cmap_ in enumerate (trial .function_space ().topological .entity_node_map (integral_type )):
1373+ #for i, rmap_ in enumerate(test.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1374+ # for j, cmap_ in enumerate(trial.function_space().topological.entity_node_map(mesh.topology, integral_type, None, None)):
1375+ # maps_and_regions[(i, j)][(rmap_, cmap_)].add(region)
1376+ for i , Vrow in enumerate (test .function_space ()):
1377+ for j , Vcol in enumerate (trial .function_space ()):
1378+ mesh = Vrow .mesh ()
1379+ rmap_ = Vrow .topological .entity_node_map (mesh .topology , integral_type , None , None )
1380+ cmap_ = Vcol .topological .entity_node_map (mesh .topology , integral_type , None , None )
13711381 maps_and_regions [(i , j )][(rmap_ , cmap_ )].add (region )
13721382 return {block_indices : [map_pair + (tuple (region_set ), ) for map_pair , region_set in map_pair_to_region_set .items ()]
13731383 for block_indices , map_pair_to_region_set in maps_and_regions .items ()}
@@ -1389,7 +1399,7 @@ def _all_local_kernels(self):
13891399 When constructing sparsity, we use all parloop_builders
13901400 that are to be used in the actual assembly.
13911401 """
1392- all_local_kernels = tuple ( local_kernel for local_kernel , _ in self .local_kernels )
1402+ all_local_kernels = self .local_kernels
13931403 for bc in self ._bcs :
13941404 if isinstance (bc , EquationBCSplit ):
13951405 _assembler = type (self )(bc .f , bcs = bc .bcs , form_compiler_parameters = self ._form_compiler_params , needs_zeroing = False )
@@ -1559,7 +1569,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
15591569 self ._form = form
15601570 self ._indices , self ._kinfo = local_knl
15611571 self ._subdomain_id = subdomain_id
1562- self ._all_integer_subdomain_ids = all_integer_subdomain_ids . get ( self . _kinfo . integral_type , None )
1572+ self ._all_integer_subdomain_ids = all_integer_subdomain_ids
15631573 self ._diagonal = diagonal
15641574 self ._unroll = unroll
15651575
@@ -1626,7 +1636,7 @@ def _needs_subset(self):
16261636 if self ._subdomain_id == "everywhere" :
16271637 return False
16281638 elif self ._subdomain_id == "otherwise" :
1629- return self ._all_integer_subdomain_ids is not None
1639+ return self ._all_integer_subdomain_ids . get ( self . _kinfo . integral_type , None ) is not None
16301640 else :
16311641 return True
16321642
@@ -1646,7 +1656,7 @@ def _get_dim(self, finat_element):
16461656
16471657 def _make_dat_global_kernel_arg (self , V , index = None ):
16481658 finat_element = create_element (V .ufl_element ())
1649- map_arg = V .topological .entity_node_map (self ._integral_type )._global_kernel_arg
1659+ map_arg = V .topological .entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )._global_kernel_arg
16501660 if isinstance (finat_element , finat .EnrichedElement ) and finat_element .is_mixed :
16511661 assert index is None
16521662 subargs = tuple (self ._make_dat_global_kernel_arg (Vsub , index = index )
@@ -1664,7 +1674,7 @@ def _make_mat_global_kernel_arg(self, Vrow, Vcol):
16641674 shape = len (relem .elements ), len (celem .elements )
16651675 return op2 .MixedMatKernelArg (subargs , shape )
16661676 else :
1667- rmap_arg , cmap_arg = (V .topological .entity_node_map (self ._integral_type )._global_kernel_arg for V in [Vrow , Vcol ])
1677+ rmap_arg , cmap_arg = (V .topological .entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )._global_kernel_arg for V in [Vrow , Vcol ])
16681678 # PyOP2 matrix objects have scalar dims so we flatten them here
16691679 rdim = numpy .prod (self ._get_dim (relem ), dtype = int )
16701680 cdim = numpy .prod (self ._get_dim (celem ), dtype = int )
@@ -1765,14 +1775,24 @@ def _as_global_kernel_arg_constant(_, self):
17651775
17661776@_as_global_kernel_arg .register (kernel_args .ExteriorFacetKernelArg )
17671777def _as_global_kernel_arg_exterior_facet (_ , self ):
1768- _ = next (self ._active_exterior_facets )
1769- return op2 .DatKernelArg ((1 ,))
1778+ mesh , _ = next (self ._active_exterior_facets )
1779+ if mesh is self ._mesh :
1780+ return op2 .DatKernelArg ((1 ,))
1781+ else :
1782+ m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
1783+ assert integral_type == "exterior_facet"
1784+ return op2 .DatKernelArg ((1 ,), m ._global_kernel_arg )
17701785
17711786
17721787@_as_global_kernel_arg .register (kernel_args .InteriorFacetKernelArg )
17731788def _as_global_kernel_arg_interior_facet (_ , self ):
1774- _ = next (self ._active_interior_facets )
1775- return op2 .DatKernelArg ((2 ,))
1789+ mesh , _ = next (self ._active_interior_facets )
1790+ if mesh is self ._mesh :
1791+ return op2 .DatKernelArg ((2 ,))
1792+ else :
1793+ m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
1794+ assert integral_type == "interior_facet"
1795+ return op2 .DatKernelArg ((2 ,), m ._global_kernel_arg )
17761796
17771797
17781798@_as_global_kernel_arg .register (CellFacetKernelArg )
@@ -1978,7 +1998,7 @@ def _iterset(self):
19781998 def _get_map (self , V ):
19791999 """Return the appropriate PyOP2 map for a given function space."""
19802000 assert isinstance (V , (WithGeometry , FiredrakeDualSpace , FunctionSpace ))
1981- return V .entity_node_map (self ._integral_type )
2001+ return V .topological . entity_node_map (self ._mesh . topology , self . _integral_type , self . _subdomain_id , self . _all_integer_subdomain_ids )
19822002
19832003 def _as_parloop_arg (self , tsfc_arg ):
19842004 """Return a :class:`op2.ParloopArg` corresponding to the provided
@@ -2066,7 +2086,7 @@ def _as_parloop_arg_exterior_facet(_, self):
20662086 m = None
20672087 else :
20682088 m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
2069- assert integral_type == "exterior_facets "
2089+ assert integral_type == "exterior_facet "
20702090 return op2 .DatParloopArg (local_facet_dat , m )
20712091
20722092
@@ -2077,7 +2097,7 @@ def _as_parloop_arg_interior_facet(_, self):
20772097 m = None
20782098 else :
20792099 m , integral_type = mesh .topology .trans_mesh_entity_map (self ._mesh .topology , self ._integral_type , self ._subdomain_id , self ._all_integer_subdomain_ids )
2080- assert integral_type == "interior_facets "
2100+ assert integral_type == "interior_facet "
20812101 return op2 .DatParloopArg (local_facet_dat , m )
20822102
20832103
0 commit comments