Skip to content

Commit af3b325

Browse files
committed
Merge branch 'main' into pbrubeck/matis
2 parents a2d8f30 + 0528493 commit af3b325

29 files changed

+1831
-396
lines changed

firedrake/assemble.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,7 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
16661666
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
16671667
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
16681668
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
1669+
self._active_orientations_cell = _FormHandler.iter_active_orientations_cell(form, local_knl.kinfo)
16691670
self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo)
16701671
self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo)
16711672

@@ -1688,6 +1689,7 @@ def build(self):
16881689
assert_empty(self._constants)
16891690
assert_empty(self._active_exterior_facets)
16901691
assert_empty(self._active_interior_facets)
1692+
assert_empty(self._active_orientations_cell)
16911693
assert_empty(self._active_orientations_exterior_facet)
16921694
assert_empty(self._active_orientations_interior_facet)
16931695

@@ -1885,6 +1887,17 @@ def _as_global_kernel_arg_interior_facet(_, self):
18851887
return op2.DatKernelArg((2,), m._global_kernel_arg)
18861888

18871889

1890+
@_as_global_kernel_arg.register(kernel_args.OrientationsCellKernelArg)
1891+
def _(_, self):
1892+
mesh = next(self._active_orientations_cell)
1893+
if mesh is self._mesh:
1894+
return op2.DatKernelArg((1,))
1895+
else:
1896+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
1897+
assert integral_type == "cell"
1898+
return op2.DatKernelArg((1,), m._global_kernel_arg)
1899+
1900+
18881901
@_as_global_kernel_arg.register(kernel_args.OrientationsExteriorFacetKernelArg)
18891902
def _(_, self):
18901903
mesh = next(self._active_orientations_exterior_facet)
@@ -1956,6 +1969,7 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
19561969
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
19571970
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
19581971
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
1972+
self._active_orientations_cell = _FormHandler.iter_active_orientations_cell(form, local_knl.kinfo)
19591973
self._active_orientations_exterior_facet = _FormHandler.iter_active_orientations_exterior_facet(form, local_knl.kinfo)
19601974
self._active_orientations_interior_facet = _FormHandler.iter_active_orientations_interior_facet(form, local_knl.kinfo)
19611975

@@ -2223,6 +2237,17 @@ def _as_parloop_arg_interior_facet(_, self):
22232237
return op2.DatParloopArg(mesh.interior_facets.local_facet_dat, m)
22242238

22252239

2240+
@_as_parloop_arg.register(kernel_args.OrientationsCellKernelArg)
2241+
def _(_, self):
2242+
mesh = next(self._active_orientations_cell)
2243+
if mesh is self._mesh:
2244+
m = None
2245+
else:
2246+
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
2247+
assert integral_type == "cell"
2248+
return op2.DatParloopArg(mesh.local_cell_orientation_dat, m)
2249+
2250+
22262251
@_as_parloop_arg.register(kernel_args.OrientationsExteriorFacetKernelArg)
22272252
def _(_, self):
22282253
mesh = next(self._active_orientations_exterior_facet)
@@ -2319,6 +2344,14 @@ def iter_active_interior_facets(form, kinfo):
23192344
mesh = all_meshes[i]
23202345
yield mesh
23212346

2347+
@staticmethod
2348+
def iter_active_orientations_cell(form, kinfo):
2349+
"""Yield the form cell orientations referenced in ``kinfo``."""
2350+
all_meshes = extract_domains(form)
2351+
for i in kinfo.active_domain_numbers.orientations_cell:
2352+
mesh = all_meshes[i]
2353+
yield mesh
2354+
23222355
@staticmethod
23232356
def iter_active_orientations_exterior_facet(form, kinfo):
23242357
"""Yield the form exterior facet orientations referenced in ``kinfo``."""

firedrake/assign.py

Lines changed: 157 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from pyadjoint.tape import annotate_tape
6+
from pyop2 import op2
67
from pyop2.utils import cached_property
78
import pytools
89
import finat.ufl
@@ -12,19 +13,22 @@
1213
from ufl.corealg.multifunction import MultiFunction
1314
from ufl.domain import extract_unique_domain
1415

16+
from firedrake.cofunction import Cofunction
1517
from firedrake.constant import Constant
1618
from firedrake.function import Function
1719
from firedrake.petsc import PETSc
1820
from firedrake.utils import ScalarType, split_by
1921

22+
from mpi4py import MPI
23+
2024

2125
def _isconstant(expr):
2226
return isinstance(expr, Constant) or \
23-
(isinstance(expr, Function) and expr.ufl_element().family() == "Real")
27+
(isinstance(expr, (Function, Cofunction)) and expr.ufl_element().family() == "Real")
2428

2529

2630
def _isfunction(expr):
27-
return isinstance(expr, Function) and expr.ufl_element().family() != "Real"
31+
return isinstance(expr, (Function, Cofunction)) and expr.ufl_element().family() != "Real"
2832

2933

3034
class CoefficientCollector(MultiFunction):
@@ -99,6 +103,9 @@ def component_tensor(self, o, a, _):
99103
def coefficient(self, o):
100104
return ((o, 1),)
101105

106+
def cofunction(self, o):
107+
return ((o, 1),)
108+
102109
def constant_value(self, o):
103110
return ((o, 1),)
104111

@@ -130,34 +137,56 @@ def _as_scalar(self, weighted_coefficients):
130137

131138

132139
class Assigner:
133-
"""Class performing pointwise assignment of an expression to a :class:`firedrake.function.Function`.
140+
"""Class performing pointwise assignment of an expression to a function or a cofunction.
141+
142+
Parameters
143+
----------
144+
assignee : firedrake.function.Function or firedrake.cofunction.Cofunction
145+
Function or Cofunction being assigned to.
146+
expression : ufl.core.expr.Expr or ufl.form.BaseForm
147+
Expression to be assigned.
148+
subset : pyop2.types.set.Set or pyop2.types.set.Subset or pyop2.types.set.MixedSet
149+
Subset to apply the assignment over.
134150
135-
:param assignee: The :class:`~.firedrake.function.Function` being assigned to.
136-
:param expression: The :class:`ufl.core.expr.Expr` to evaluate.
137-
:param subset: Optional subset (:class:`pyop2.types.set.Subset`) to apply the assignment over.
138151
"""
139152
symbol = "="
140153

141154
_coefficient_collector = CoefficientCollector()
142155

143156
def __init__(self, assignee, expression, subset=None):
144157
expression = as_ufl(expression)
145-
158+
source_meshes = set()
146159
for coeff in extract_coefficients(expression):
147-
if isinstance(coeff, Function) and coeff.ufl_element().family() != "Real":
160+
if isinstance(coeff, (Function, Cofunction)) and coeff.ufl_element().family() != "Real":
148161
if coeff.ufl_element() != assignee.ufl_element():
149162
raise ValueError("All functions in the expression must have the same "
150163
"element as the assignee")
151-
if extract_unique_domain(coeff) != extract_unique_domain(assignee):
152-
raise ValueError("All functions in the expression must use the same "
153-
"mesh as the assignee")
154-
155-
if (subset and type(assignee.ufl_element()) == finat.ufl.MixedElement
156-
and any(el.family() == "Real"
157-
for el in assignee.ufl_element().sub_elements)):
158-
raise ValueError("Subset is not a valid argument for assigning to a mixed "
159-
"element including a real element")
160-
164+
source_meshes.add(extract_unique_domain(coeff))
165+
if len(source_meshes) == 0:
166+
pass
167+
elif len(source_meshes) == 1:
168+
target_mesh = extract_unique_domain(assignee)
169+
source_mesh, = source_meshes
170+
if target_mesh.submesh_youngest_common_ancester(source_mesh) is None:
171+
raise ValueError(
172+
"All functions in the expression must be defined on a single domain "
173+
"that is in the same submesh family as domain of the assignee"
174+
)
175+
else:
176+
raise ValueError(
177+
"All functions in the expression must be defined on a single domain"
178+
)
179+
if subset is None:
180+
subset = tuple(None for _ in assignee.function_space())
181+
if len(subset) != len(assignee.function_space()):
182+
raise ValueError(f"Provided subset ({subset}) incompatible with assignee ({assignee})")
183+
if type(assignee.ufl_element()) == finat.ufl.MixedElement:
184+
for subs, el in zip(subset, assignee.function_space().ufl_element().sub_elements):
185+
if subs is not None and el.family() == "Real":
186+
raise ValueError(
187+
"Subset is not a valid argument for assigning to a mixed "
188+
"element including a real element"
189+
)
161190
self._assignee = assignee
162191
self._expression = expression
163192
self._subset = subset
@@ -169,14 +198,21 @@ def __repr__(self):
169198
return f"{self.__class__.__name__}({self._assignee!r}, {self._expression!r})"
170199

171200
@PETSc.Log.EventDecorator()
172-
def assign(self):
173-
"""Perform the assignment."""
201+
def assign(self, allow_missing_dofs=False):
202+
"""Perform the assignment.
203+
204+
Parameters
205+
----------
206+
allow_missing_dofs : bool
207+
Permit assignment between objects with mismatching nodes. If `True` then
208+
assignee nodes with no matching assigner nodes are ignored.
209+
210+
"""
174211
if annotate_tape():
175212
raise NotImplementedError(
176213
"Taping with explicit Assigner objects is not supported yet. "
177214
"Use Function.assign instead."
178215
)
179-
180216
# To minimize communication during assignment we perform a number of tricks:
181217
# * If we are not assigning to a subset then we can always write to the
182218
# halo. The validity of the original assignee dat halo does not matter
@@ -191,28 +227,111 @@ def assign(self):
191227
# end up doing a lot of halo exchanges for the expression just to avoid
192228
# a single halo exchange for the assignee.
193229
# * If we do write to the halo then the resulting halo will never be dirty.
194-
195-
func_halos_valid = all(f.dat.halo_valid for f in self._functions)
196-
assign_to_halos = (
197-
func_halos_valid and (not self._subset or self._assignee.dat.halo_valid))
198-
230+
# If mixed, loop over individual components
231+
for lhs_func, subset, *funcs in zip(self._assignee.subfunctions, self._subset, *(f.subfunctions for f in self._functions)):
232+
target_mesh = extract_unique_domain(lhs_func)
233+
target_V = lhs_func.function_space()
234+
# Validate / Process subset.
235+
if subset is not None:
236+
if subset is target_V.node_set:
237+
# The whole set.
238+
subset = None
239+
elif subset.superset is target_V.node_set:
240+
# op2.Subset of target_V.node_set
241+
pass
242+
else:
243+
raise ValueError(f"subset ({subset}) not a subset of target_V.node_set ({target_V.node_set})")
244+
source_meshes = set(extract_unique_domain(f) for f in funcs)
245+
if len(source_meshes) == 0:
246+
# Assign constants only.
247+
single_mesh_assign = True
248+
elif len(source_meshes) == 1:
249+
source_mesh, = source_meshes
250+
if target_mesh is source_mesh:
251+
# Assign (co)functions from one mesh to the same mesh.
252+
single_mesh_assign = True
253+
else:
254+
# Assign (co)functions between a submesh and the parent or between two submeshes.
255+
single_mesh_assign = False
256+
else:
257+
raise ValueError("All functions in the expression must be defined on a single domain")
258+
if single_mesh_assign:
259+
self._assign_single_mesh(lhs_func, subset, funcs, operator)
260+
else:
261+
self._assign_multi_mesh(lhs_func, subset, funcs, operator, allow_missing_dofs)
262+
263+
def _assign_single_mesh(self, lhs_func, subset, funcs, operator):
264+
assign_to_halos = all(f.dat.halo_valid for f in funcs) and (lhs_func.dat.halo_valid or subset is None)
199265
if assign_to_halos:
200-
subset_indices = self._subset.indices if self._subset else ...
266+
subset_indices = ... if subset is None else subset.indices
201267
data_ro = operator.attrgetter("data_ro_with_halos")
202268
else:
203-
subset_indices = self._subset.owned_indices if self._subset else ...
269+
subset_indices = ... if subset is None else subset.owned_indices
204270
data_ro = operator.attrgetter("data_ro")
205-
206-
# If mixed, loop over individual components
207-
for lhs_dat, *func_dats in zip(self._assignee.dat.split,
208-
*(f.dat.split for f in self._functions)):
209-
func_data = np.array([data_ro(f)[subset_indices] for f in func_dats])
210-
rvalue = self._compute_rvalue(func_data)
211-
self._assign_single_dat(lhs_dat, subset_indices, rvalue, assign_to_halos)
212-
213-
# if we have bothered writing to halo it naturally must not be dirty
271+
func_data = np.array([data_ro(f.dat)[subset_indices] for f in funcs])
272+
rvalue = self._compute_rvalue(func_data)
273+
self._assign_single_dat(lhs_func.dat, subset_indices, rvalue, assign_to_halos)
214274
if assign_to_halos:
215-
self._assignee.dat.halo_valid = True
275+
lhs_func.dat.halo_valid = True
276+
277+
def _assign_multi_mesh(self, lhs_func, subset, funcs, operator, allow_missing_dofs):
278+
target_mesh = extract_unique_domain(lhs_func)
279+
target_V = lhs_func.function_space()
280+
source_V, = set(f.function_space() for f in funcs)
281+
composed_map = source_V.topological.entity_node_map(target_mesh.topology, "cell", "everywhere", None)
282+
indices_active = composed_map.indices_active_with_halo
283+
indices_active_all = indices_active.all()
284+
indices_active_all = target_mesh.comm.allreduce(indices_active_all, op=MPI.LAND)
285+
if subset is None:
286+
if not indices_active_all and not allow_missing_dofs:
287+
raise ValueError("Found assignee nodes with no matching assigner nodes: run with `allow_missing_dofs=True`")
288+
subset_indices_target = target_V.cell_node_map().values_with_halo[indices_active, :].flatten()
289+
subset_indices_source = composed_map.values_with_halo[indices_active, :].flatten()
290+
else:
291+
subset_indices_target, perm, _ = np.intersect1d(
292+
target_V.cell_node_map().values_with_halo[indices_active, :].flatten(),
293+
subset.indices,
294+
return_indices=True,
295+
)
296+
if len(subset.indices) > len(subset_indices_target) and not allow_missing_dofs:
297+
raise ValueError("Found assignee nodes with no matching assigner nodes: run with `allow_missing_dofs=True`")
298+
subset_indices_source = composed_map.values_with_halo[indices_active, :].flatten()[perm]
299+
# Use buffer array to make sure that owned DoFs are updated upon assigning.
300+
# The following example illustrates the issue that a naive assignment would cause.
301+
#
302+
# Consider the following target/source meshes distributed over 2 processes
303+
# with no partition overlap:
304+
#
305+
# 0----0----0----1----1
306+
# | | |
307+
# target 0 0 0 1 1
308+
# (parent mesh) | | |
309+
# 0----0----0----1----1 (owning ranks are shown)
310+
#
311+
# 1----1----1
312+
# | |
313+
# source 1 1 1
314+
# (submesh) | |
315+
# 1----1----1 (owning ranks are shown)
316+
#
317+
# Consider CG1 functions f (on parent) and fsub (on submesh). By a naive
318+
# f.assign(fsub, subset=...), the DoFs shared by rank 0 and rank 1 would
319+
# only be updated on rank 1, which sees those DoFs as ghost, and those
320+
# updated values on rank 1 would be overridden by the old values on rank 0
321+
# upon a halo exchange.
322+
#
323+
# TODO: Use work array for buffer?
324+
buffer = type(lhs_func)(target_V)
325+
finfo = np.finfo(lhs_func.dat.dtype)
326+
buffer.dat._data[:] = finfo.max
327+
func_data = np.array([f.dat.data_ro_with_halos[subset_indices_source] for f in funcs])
328+
rvalue = self._compute_rvalue(func_data)
329+
self._assign_single_dat(buffer.dat, subset_indices_target, rvalue, True)
330+
# Make all owned DoFs up-to-date; ghost DoFs may or may not be up-to-date after this.
331+
buffer.dat.local_to_global_begin(op2.MIN)
332+
buffer.dat.local_to_global_end(op2.MIN)
333+
indices = np.where(buffer.dat.data_ro_with_halos < finfo.max * 0.999999999999)
334+
lhs_func.dat.data_wo_with_halos[indices] = buffer.dat.data_ro_with_halos[indices]
216335

217336
@cached_property
218337
def _constants(self):

0 commit comments

Comments
 (0)