33
44import numpy as np
55from pyadjoint .tape import annotate_tape
6+ from pyop2 import op2
67from pyop2 .utils import cached_property
78import pytools
89import finat .ufl
1213from ufl .corealg .multifunction import MultiFunction
1314from ufl .domain import extract_unique_domain
1415
16+ from firedrake .cofunction import Cofunction
1517from firedrake .constant import Constant
1618from firedrake .function import Function
1719from firedrake .petsc import PETSc
1820from firedrake .utils import ScalarType , split_by
1921
22+ from mpi4py import MPI
23+
2024
2125def _isconstant (expr ):
2226 return isinstance (expr , Constant ) or \
23- (isinstance (expr , Function ) and expr .ufl_element ().family () == "Real" )
27+ (isinstance (expr , ( Function , Cofunction ) ) and expr .ufl_element ().family () == "Real" )
2428
2529
2630def _isfunction (expr ):
27- return isinstance (expr , Function ) and expr .ufl_element ().family () != "Real"
31+ return isinstance (expr , ( Function , Cofunction ) ) and expr .ufl_element ().family () != "Real"
2832
2933
3034class CoefficientCollector (MultiFunction ):
@@ -99,6 +103,9 @@ def component_tensor(self, o, a, _):
99103 def coefficient (self , o ):
100104 return ((o , 1 ),)
101105
106+ def cofunction (self , o ):
107+ return ((o , 1 ),)
108+
102109 def constant_value (self , o ):
103110 return ((o , 1 ),)
104111
@@ -130,34 +137,56 @@ def _as_scalar(self, weighted_coefficients):
130137
131138
132139class Assigner :
133- """Class performing pointwise assignment of an expression to a :class:`firedrake.function.Function`.
140+ """Class performing pointwise assignment of an expression to a function or a cofunction.
141+
142+ Parameters
143+ ----------
144+ assignee : firedrake.function.Function or firedrake.cofunction.Cofunction
145+ Function or Cofunction being assigned to.
146+ expression : ufl.core.expr.Expr or ufl.form.BaseForm
147+ Expression to be assigned.
148+ subset : pyop2.types.set.Set or pyop2.types.set.Subset or pyop2.types.set.MixedSet
149+ Subset to apply the assignment over.
134150
135- :param assignee: The :class:`~.firedrake.function.Function` being assigned to.
136- :param expression: The :class:`ufl.core.expr.Expr` to evaluate.
137- :param subset: Optional subset (:class:`pyop2.types.set.Subset`) to apply the assignment over.
138151 """
139152 symbol = "="
140153
141154 _coefficient_collector = CoefficientCollector ()
142155
143156 def __init__ (self , assignee , expression , subset = None ):
144157 expression = as_ufl (expression )
145-
158+ source_meshes = set ()
146159 for coeff in extract_coefficients (expression ):
147- if isinstance (coeff , Function ) and coeff .ufl_element ().family () != "Real" :
160+ if isinstance (coeff , ( Function , Cofunction ) ) and coeff .ufl_element ().family () != "Real" :
148161 if coeff .ufl_element () != assignee .ufl_element ():
149162 raise ValueError ("All functions in the expression must have the same "
150163 "element as the assignee" )
151- if extract_unique_domain (coeff ) != extract_unique_domain (assignee ):
152- raise ValueError ("All functions in the expression must use the same "
153- "mesh as the assignee" )
154-
155- if (subset and type (assignee .ufl_element ()) == finat .ufl .MixedElement
156- and any (el .family () == "Real"
157- for el in assignee .ufl_element ().sub_elements )):
158- raise ValueError ("Subset is not a valid argument for assigning to a mixed "
159- "element including a real element" )
160-
164+ source_meshes .add (extract_unique_domain (coeff ))
165+ if len (source_meshes ) == 0 :
166+ pass
167+ elif len (source_meshes ) == 1 :
168+ target_mesh = extract_unique_domain (assignee )
169+ source_mesh , = source_meshes
170+ if target_mesh .submesh_youngest_common_ancester (source_mesh ) is None :
171+ raise ValueError (
172+ "All functions in the expression must be defined on a single domain "
173+ "that is in the same submesh family as domain of the assignee"
174+ )
175+ else :
176+ raise ValueError (
177+ "All functions in the expression must be defined on a single domain"
178+ )
179+ if subset is None :
180+ subset = tuple (None for _ in assignee .function_space ())
181+ if len (subset ) != len (assignee .function_space ()):
182+ raise ValueError (f"Provided subset ({ subset } ) incompatible with assignee ({ assignee } )" )
183+ if type (assignee .ufl_element ()) == finat .ufl .MixedElement :
184+ for subs , el in zip (subset , assignee .function_space ().ufl_element ().sub_elements ):
185+ if subs is not None and el .family () == "Real" :
186+ raise ValueError (
187+ "Subset is not a valid argument for assigning to a mixed "
188+ "element including a real element"
189+ )
161190 self ._assignee = assignee
162191 self ._expression = expression
163192 self ._subset = subset
@@ -169,14 +198,21 @@ def __repr__(self):
169198 return f"{ self .__class__ .__name__ } ({ self ._assignee !r} , { self ._expression !r} )"
170199
171200 @PETSc .Log .EventDecorator ()
172- def assign (self ):
173- """Perform the assignment."""
201+ def assign (self , allow_missing_dofs = False ):
202+ """Perform the assignment.
203+
204+ Parameters
205+ ----------
206+ allow_missing_dofs : bool
207+ Permit assignment between objects with mismatching nodes. If `True` then
208+ assignee nodes with no matching assigner nodes are ignored.
209+
210+ """
174211 if annotate_tape ():
175212 raise NotImplementedError (
176213 "Taping with explicit Assigner objects is not supported yet. "
177214 "Use Function.assign instead."
178215 )
179-
180216 # To minimize communication during assignment we perform a number of tricks:
181217 # * If we are not assigning to a subset then we can always write to the
182218 # halo. The validity of the original assignee dat halo does not matter
@@ -191,28 +227,111 @@ def assign(self):
191227 # end up doing a lot of halo exchanges for the expression just to avoid
192228 # a single halo exchange for the assignee.
193229 # * If we do write to the halo then the resulting halo will never be dirty.
194-
195- func_halos_valid = all (f .dat .halo_valid for f in self ._functions )
196- assign_to_halos = (
197- func_halos_valid and (not self ._subset or self ._assignee .dat .halo_valid ))
198-
230+ # If mixed, loop over individual components
231+ for lhs_func , subset , * funcs in zip (self ._assignee .subfunctions , self ._subset , * (f .subfunctions for f in self ._functions )):
232+ target_mesh = extract_unique_domain (lhs_func )
233+ target_V = lhs_func .function_space ()
234+ # Validate / Process subset.
235+ if subset is not None :
236+ if subset is target_V .node_set :
237+ # The whole set.
238+ subset = None
239+ elif subset .superset is target_V .node_set :
240+ # op2.Subset of target_V.node_set
241+ pass
242+ else :
243+ raise ValueError (f"subset ({ subset } ) not a subset of target_V.node_set ({ target_V .node_set } )" )
244+ source_meshes = set (extract_unique_domain (f ) for f in funcs )
245+ if len (source_meshes ) == 0 :
246+ # Assign constants only.
247+ single_mesh_assign = True
248+ elif len (source_meshes ) == 1 :
249+ source_mesh , = source_meshes
250+ if target_mesh is source_mesh :
251+ # Assign (co)functions from one mesh to the same mesh.
252+ single_mesh_assign = True
253+ else :
254+ # Assign (co)functions between a submesh and the parent or between two submeshes.
255+ single_mesh_assign = False
256+ else :
257+ raise ValueError ("All functions in the expression must be defined on a single domain" )
258+ if single_mesh_assign :
259+ self ._assign_single_mesh (lhs_func , subset , funcs , operator )
260+ else :
261+ self ._assign_multi_mesh (lhs_func , subset , funcs , operator , allow_missing_dofs )
262+
263+ def _assign_single_mesh (self , lhs_func , subset , funcs , operator ):
264+ assign_to_halos = all (f .dat .halo_valid for f in funcs ) and (lhs_func .dat .halo_valid or subset is None )
199265 if assign_to_halos :
200- subset_indices = self . _subset . indices if self . _subset else ...
266+ subset_indices = ... if subset is None else subset . indices
201267 data_ro = operator .attrgetter ("data_ro_with_halos" )
202268 else :
203- subset_indices = self . _subset . owned_indices if self . _subset else ...
269+ subset_indices = ... if subset is None else subset . owned_indices
204270 data_ro = operator .attrgetter ("data_ro" )
205-
206- # If mixed, loop over individual components
207- for lhs_dat , * func_dats in zip (self ._assignee .dat .split ,
208- * (f .dat .split for f in self ._functions )):
209- func_data = np .array ([data_ro (f )[subset_indices ] for f in func_dats ])
210- rvalue = self ._compute_rvalue (func_data )
211- self ._assign_single_dat (lhs_dat , subset_indices , rvalue , assign_to_halos )
212-
213- # if we have bothered writing to halo it naturally must not be dirty
271+ func_data = np .array ([data_ro (f .dat )[subset_indices ] for f in funcs ])
272+ rvalue = self ._compute_rvalue (func_data )
273+ self ._assign_single_dat (lhs_func .dat , subset_indices , rvalue , assign_to_halos )
214274 if assign_to_halos :
215- self ._assignee .dat .halo_valid = True
275+ lhs_func .dat .halo_valid = True
276+
277+ def _assign_multi_mesh (self , lhs_func , subset , funcs , operator , allow_missing_dofs ):
278+ target_mesh = extract_unique_domain (lhs_func )
279+ target_V = lhs_func .function_space ()
280+ source_V , = set (f .function_space () for f in funcs )
281+ composed_map = source_V .topological .entity_node_map (target_mesh .topology , "cell" , "everywhere" , None )
282+ indices_active = composed_map .indices_active_with_halo
283+ indices_active_all = indices_active .all ()
284+ indices_active_all = target_mesh .comm .allreduce (indices_active_all , op = MPI .LAND )
285+ if subset is None :
286+ if not indices_active_all and not allow_missing_dofs :
287+ raise ValueError ("Found assignee nodes with no matching assigner nodes: run with `allow_missing_dofs=True`" )
288+ subset_indices_target = target_V .cell_node_map ().values_with_halo [indices_active , :].flatten ()
289+ subset_indices_source = composed_map .values_with_halo [indices_active , :].flatten ()
290+ else :
291+ subset_indices_target , perm , _ = np .intersect1d (
292+ target_V .cell_node_map ().values_with_halo [indices_active , :].flatten (),
293+ subset .indices ,
294+ return_indices = True ,
295+ )
296+ if len (subset .indices ) > len (subset_indices_target ) and not allow_missing_dofs :
297+ raise ValueError ("Found assignee nodes with no matching assigner nodes: run with `allow_missing_dofs=True`" )
298+ subset_indices_source = composed_map .values_with_halo [indices_active , :].flatten ()[perm ]
299+ # Use buffer array to make sure that owned DoFs are updated upon assigning.
300+ # The following example illustrates the issue that a naive assignment would cause.
301+ #
302+ # Consider the following target/source meshes distributed over 2 processes
303+ # with no partition overlap:
304+ #
305+ # 0----0----0----1----1
306+ # | | |
307+ # target 0 0 0 1 1
308+ # (parent mesh) | | |
309+ # 0----0----0----1----1 (owning ranks are shown)
310+ #
311+ # 1----1----1
312+ # | |
313+ # source 1 1 1
314+ # (submesh) | |
315+ # 1----1----1 (owning ranks are shown)
316+ #
317+ # Consider CG1 functions f (on parent) and fsub (on submesh). By a naive
318+ # f.assign(fsub, subset=...), the DoFs shared by rank 0 and rank 1 would
319+ # only be updated on rank 1, which sees those DoFs as ghost, and those
320+ # updated values on rank 1 would be overridden by the old values on rank 0
321+ # upon a halo exchange.
322+ #
323+ # TODO: Use work array for buffer?
324+ buffer = type (lhs_func )(target_V )
325+ finfo = np .finfo (lhs_func .dat .dtype )
326+ buffer .dat ._data [:] = finfo .max
327+ func_data = np .array ([f .dat .data_ro_with_halos [subset_indices_source ] for f in funcs ])
328+ rvalue = self ._compute_rvalue (func_data )
329+ self ._assign_single_dat (buffer .dat , subset_indices_target , rvalue , True )
330+ # Make all owned DoFs up-to-date; ghost DoFs may or may not be up-to-date after this.
331+ buffer .dat .local_to_global_begin (op2 .MIN )
332+ buffer .dat .local_to_global_end (op2 .MIN )
333+ indices = np .where (buffer .dat .data_ro_with_halos < finfo .max * 0.999999999999 )
334+ lhs_func .dat .data_wo_with_halos [indices ] = buffer .dat .data_ro_with_halos [indices ]
216335
217336 @cached_property
218337 def _constants (self ):
0 commit comments