Skip to content

Commit 9b4f147

Browse files
pbrubeckdham
andauthored
Dual interpolation from/into MixedFunctionSpace (#4197)
* Interpolation onto subfunctions * More generic approach using fieldsplit * Interpolate: Fix parloop args ordering * Fix dual * Interpolate from MixedFunctionSpace * Test Interpolate from vector/mixed to mixed * Interpolate from (scalar/mixed) to (scalar/mixed) Co-authored-by: Pablo Brubeck <brubeck@protonmail.com> --------- Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
1 parent 6238f05 commit 9b4f147

File tree

8 files changed

+229
-77
lines changed

8 files changed

+229
-77
lines changed

firedrake/adjoint_utils/blocks/assembly.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
101101
c = block_variable.output
102102
c_rep = block_variable.saved_output
103103

104-
from ufl.algorithms.analysis import extract_arguments
105-
arity_form = len(extract_arguments(form))
104+
arity_form = len(form.arguments())
106105

107106
if isconstant(c):
108107
mesh = as_domain(self.form)
@@ -157,8 +156,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
157156
hessian_input = hessian_inputs[0]
158157
adj_input = adj_inputs[0]
159158

160-
from ufl.algorithms.analysis import extract_arguments
161-
arity_form = len(extract_arguments(form))
159+
arity_form = len(form.arguments())
162160

163161
c1 = block_variable.output
164162
c1_rep = block_variable.saved_output

firedrake/assemble.py

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import finat.ufl
1919
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
2020
tsfc_interface, utils)
21+
from firedrake.formmanipulation import split_form
2122
from firedrake.adjoint_utils import annotate_assemble
2223
from firedrake.ufl_expr import extract_unique_domain
2324
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
@@ -230,34 +231,36 @@ def assemble(self, tensor=None, current_state=None):
230231
# Only Expr resulting in a Matrix if assembled are BaseFormOperator
231232
if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)):
232233
raise TypeError('Mismatching Sum shapes')
233-
return get_assembler(ufl.FormSum((a, 1), (b, 1))).assemble()
234+
return assemble(ufl.FormSum((a, 1), (b, 1)), tensor=tensor)
234235
elif isinstance(expr, ufl.algebra.Product):
235236
a, b = expr.ufl_operands
236237
scalar = [e for e in expr.ufl_operands if is_scalar_constant_expression(e)]
237238
if scalar:
238239
base_form = a if a is scalar else b
239240
assembled_mat = assemble(base_form)
240-
return get_assembler(ufl.FormSum((assembled_mat, scalar[0]))).assemble()
241+
return assemble(ufl.FormSum((assembled_mat, scalar[0])), tensor=tensor)
241242
a, b = [assemble(e) for e in (a, b)]
242-
return get_assembler(ufl.action(a, b)).assemble()
243+
return assemble(ufl.action(a, b), tensor=tensor)
243244
# -- Linear combination of Functions and 1-form BaseFormOperators -- #
244245
# Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*)
245246
# with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators.
246247
else:
247248
base_form_operators = extract_base_form_operators(expr)
248-
assembled_bfops = [firedrake.assemble(e) for e in base_form_operators]
249249
# Substitute base form operators with their output before examining the expression
250250
# which avoids conflict when determining function space, for example:
251251
# extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1
252252
# instead of V2.
253253
if base_form_operators:
254-
expr = ufl.replace(expr, dict(zip(base_form_operators, assembled_bfops)))
255-
try:
256-
coefficients = ufl.algorithms.extract_coefficients(expr)
257-
V, = set(c.function_space() for c in coefficients) - {None}
258-
except ValueError:
259-
raise ValueError("Cannot deduce correct target space from pointwise expression")
260-
return firedrake.Function(V).assign(expr)
254+
assembled_bfops = {e: firedrake.assemble(e) for e in base_form_operators}
255+
expr = ufl.replace(expr, assembled_bfops)
256+
if tensor is None:
257+
try:
258+
coefficients = ufl.algorithms.extract_coefficients(expr)
259+
V, = set(c.function_space() for c in coefficients) - {None}
260+
except ValueError:
261+
raise ValueError("Cannot deduce correct target space from pointwise expression")
262+
tensor = firedrake.Function(V)
263+
return tensor.assign(expr)
261264

262265

263266
class AbstractFormAssembler(abc.ABC):
@@ -493,9 +496,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
493496
if all(isinstance(op, numbers.Complex) for op in args):
494497
result = sum(weight * arg for weight, arg in zip(expr.weights(), args))
495498
return tensor.assign(result) if tensor else result
496-
elif all(isinstance(op, firedrake.Cofunction) for op in args):
499+
elif (all(isinstance(op, firedrake.Cofunction) for op in args)
500+
or all(isinstance(op, firedrake.Function) for op in args)):
497501
V, = set(a.function_space() for a in args)
498-
result = tensor if tensor else firedrake.Cofunction(V)
502+
result = tensor if tensor else firedrake.Function(V)
499503
result.dat.maxpy(expr.weights(), [a.dat for a in args])
500504
return result
501505
elif all(isinstance(op, ufl.Matrix) for op in args):
@@ -540,7 +544,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
540544
if assembled_expression:
541545
# Occur in situations such as Interpolate composition
542546
expression = assembled_expression[0]
543-
expr = expr._ufl_expr_reconstruct_(expression, v)
547+
548+
reconstruct_interp = expr._ufl_expr_reconstruct_
549+
if (v, expression) != expr.argument_slots():
550+
expr = reconstruct_interp(expression, v=v)
544551

545552
# Different assembly procedures:
546553
# 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix)
@@ -552,27 +559,59 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
552559
# If argument numbers have been swapped => Adjoint.
553560
arg_expression = ufl.algorithms.extract_arguments(expression)
554561
is_adjoint = (arg_expression and arg_expression[0].number() == 0)
562+
563+
# Dual interpolation from mixed source
564+
if is_adjoint and len(expr.function_space()) > 1:
565+
cur = 0
566+
sub_expressions = []
567+
components = numpy.reshape(expression, (-1,))
568+
for Vi in expr.function_space():
569+
sub_expressions.append(ufl.as_tensor(components[cur:cur+Vi.value_size].reshape(Vi.value_shape)))
570+
cur += Vi.value_size
571+
572+
# Component-split of the primal expression interpolated into the dual argument-split
573+
split_interp = sum(reconstruct_interp(sub_expressions[i], v=vi) for (i,), vi in split_form(v))
574+
return assemble(split_interp, tensor=tensor)
575+
576+
# Dual interpolation into mixed target
577+
if is_adjoint and len(arg_expression[0].function_space()) > 1 and rank == 1:
578+
V = arg_expression[0].function_space()
579+
tensor = tensor or firedrake.Cofunction(V.dual())
580+
581+
# Argument-split of the Interpolate gets assembled into the corresponding sub-tensor
582+
for (i,), sub_interp in split_form(expr):
583+
assemble(sub_interp, tensor=tensor.subfunctions[i])
584+
return tensor
585+
586+
# Get the primal space
587+
V = expr.function_space()
588+
if is_adjoint or rank == 0:
589+
V = V.dual()
590+
555591
# Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
556-
if not is_adjoint and rank != 1:
592+
if not is_adjoint and rank == 2:
557593
_, v1 = expr.arguments()
558-
expression = ufl.replace(expression, {v1: firedrake.Argument(v1.function_space(), number=0, part=v1.part())})
594+
expression = ufl.replace(expression, {v1: v1.reconstruct(number=0)})
559595
# Get the interpolator
560596
interp_data = expr.interp_data
561597
default_missing_val = interp_data.pop('default_missing_val', None)
562-
interpolator = firedrake.Interpolator(expression, expr.function_space(), **interp_data)
598+
interpolator = firedrake.Interpolator(expression, V, **interp_data)
563599
# Assembly
564-
if rank == 1:
600+
if rank == 0:
601+
Iu = interpolator._interpolate(default_missing_val=default_missing_val)
602+
return assemble(ufl.Action(v, Iu), tensor=tensor)
603+
elif rank == 1:
565604
# Assembling the action of the Jacobian adjoint.
566605
if is_adjoint:
567-
output = tensor or firedrake.Cofunction(arg_expression[0].function_space().dual())
568-
return interpolator._interpolate(v, output=output, adjoint=True, default_missing_val=default_missing_val)
606+
return interpolator._interpolate(v, output=tensor, adjoint=True, default_missing_val=default_missing_val)
569607
# Assembling the Jacobian action.
570-
if interpolator.nargs:
608+
elif interpolator.nargs:
571609
return interpolator._interpolate(expression, output=tensor, default_missing_val=default_missing_val)
572610
# Assembling the operator
573-
if tensor is None:
611+
elif tensor is None:
574612
return interpolator._interpolate(default_missing_val=default_missing_val)
575-
return firedrake.Interpolator(expression, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val)
613+
else:
614+
return firedrake.Interpolator(expression, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val)
576615
elif rank == 2:
577616
res = tensor.petscmat if tensor else PETSc.Mat()
578617
# Get the interpolation matrix
@@ -799,7 +838,7 @@ def restructure_base_form(expr, visited=None):
799838
replace_map = {arg: left}
800839
# Decrease number for all the other arguments since the lowest numbered argument will be replaced.
801840
other_args = [a for a in right.arguments() if a is not arg]
802-
new_args = [firedrake.Argument(a.function_space(), number=a.number()-1, part=a.part()) for a in other_args]
841+
new_args = [a.reconstruct(number=a.number()-1) for a in other_args]
803842
replace_map.update(dict(zip(other_args, new_args)))
804843
# Replace arguments
805844
return ufl.replace(right, replace_map)
@@ -810,13 +849,13 @@ def restructure_base_form(expr, visited=None):
810849
u, v = B.arguments()
811850
# Let V1 and V2 be primal spaces, B: V1 -> V2 and B*: V2* -> V1*:
812851
# Adjoint(B(Argument(V1, 1), Argument(V2.dual(), 0))) = B(Argument(V1, 0), Argument(V2.dual(), 1))
813-
reordered_arguments = (firedrake.Argument(u.function_space(), number=v.number(), part=v.part()),
814-
firedrake.Argument(v.function_space(), number=u.number(), part=u.part()))
852+
reordered_arguments = {u: u.reconstruct(number=v.number()),
853+
v: v.reconstruct(number=u.number())}
815854
# Replace arguments in argument slots
816-
return ufl.replace(B, dict(zip((u, v), reordered_arguments)))
855+
return ufl.replace(B, reordered_arguments)
817856

818857
# -- Case (5) -- #
819-
if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator) and not expr.arguments():
858+
if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator) and len(expr.arguments()) == 0:
820859
# We are assembling a BaseFormOperator of rank 0 (no arguments).
821860
# B(f, u*) be a BaseFormOperator with u* a Cofunction and f a Coefficient, then:
822861
# B(f, u*) <=> Action(B(f, v*), f) where v* is a Coargument

firedrake/interpolation.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,11 @@ def __init__(self, expr, v,
8585
the :meth:`interpolate` method or (b) set to zero.
8686
Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`.
8787
"""
88-
8988
# Check function space
9089
if isinstance(v, functionspaceimpl.WithGeometry):
91-
v = Argument(v.dual(), 0)
92-
93-
# Get the primal space (V** = V)
94-
vv = v if not isinstance(v, ufl.Form) else v.arguments()[0]
95-
self._function_space = vv.function_space().dual()
90+
expr_args = extract_arguments(ufl.as_ufl(expr))
91+
is_adjoint = len(expr_args) and expr_args[0].number() == 0
92+
v = Argument(v.dual(), 1 if is_adjoint else 0)
9693
super().__init__(expr, v)
9794

9895
# -- Interpolate data (e.g. `subset` or `access`) -- #
@@ -101,8 +98,7 @@ def __init__(self, expr, v,
10198
"allow_missing_dofs": allow_missing_dofs,
10299
"default_missing_val": default_missing_val}
103100

104-
def function_space(self):
105-
return self._function_space
101+
function_space = ufl.Interpolate.ufl_function_space
106102

107103
def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
108104
interp_data = interp_data or self.interp_data.copy()
@@ -293,9 +289,7 @@ def __init__(
293289
expr_args = extract_arguments(expr)
294290
if expr_args and expr_args[0].number() == 0:
295291
v, = expr_args
296-
expr = replace(expr, {v: Argument(v.function_space(),
297-
number=1,
298-
part=v.part())})
292+
expr = replace(expr, {v: v.reconstruct(number=1)})
299293
self.expr_renumbered = expr
300294

301295
def _interpolate_future(self, *function, transpose=None, adjoint=False, default_missing_val=None):
@@ -870,7 +864,7 @@ def _interpolate(self, *function, output=None, transpose=None, adjoint=False, **
870864
raise ValueError("The expression had arguments: we therefore need to be given a Function (not an expression) to interpolate!")
871865
if adjoint:
872866
mul = assembled_interpolator.handle.multHermitian
873-
V = self.arguments[0].function_space()
867+
V = self.arguments[0].function_space().dual()
874868
else:
875869
mul = assembled_interpolator.handle.mult
876870
V = self.V
@@ -1123,34 +1117,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11231117
name = kernel.name
11241118
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
11251119
flop_count=kernel.flop_count, events=(kernel.event,))
1120+
11261121
parloop_args = [kernel, cell_set]
11271122

11281123
coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
11291124
if needs_external_coords:
11301125
coefficients = [source_mesh.coordinates] + coefficients
11311126

1132-
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
1133-
if target_mesh is not source_mesh:
1134-
# NOTE: TSFC will sometimes drop run-time arguments in generated
1135-
# kernels if they are deemed not-necessary.
1136-
# FIXME: Checking for argument name in the inner kernel to decide
1137-
# whether to add an extra coefficient is a stopgap until
1138-
# compile_expression_dual_evaluation
1139-
# (a) outputs a coefficient map to indicate argument ordering in
1140-
# parloops as `compile_form` does and
1141-
# (b) allows the dual evaluation related coefficients to be supplied to
1142-
# them rather than having to be added post-hoc (likely by
1143-
# replacing `to_element` with a CoFunction/CoArgument as the
1144-
# target `dual` which would contain `dual` related
1145-
# coefficient(s))
1146-
if rt_var_name in [arg.name for arg in kernel.code[name].args]:
1147-
# Add the coordinates of the target mesh quadrature points in the
1148-
# source mesh's reference cell as an extra argument for the inner
1149-
# loop. (With a vertex only mesh this is a single point for each
1150-
# vertex cell.)
1151-
coefficients.append(target_mesh.reference_coordinates)
1152-
1153-
if tensor in set((c.dat for c in coefficients)):
1127+
if any(c.dat == tensor for c in coefficients):
11541128
output = tensor
11551129
tensor = op2.Dat(tensor.dataset)
11561130
if access is not op2.WRITE:
@@ -1196,6 +1170,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11961170
if needs_cell_sizes:
11971171
cs = source_mesh.cell_sizes
11981172
parloop_args.append(cs.dat(op2.READ, cs.cell_node_map()))
1173+
11991174
for coefficient in coefficients:
12001175
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
12011176
coeff_mesh = extract_unique_domain(coefficient)
@@ -1227,6 +1202,30 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
12271202
for const in extract_firedrake_constants(expr):
12281203
parloop_args.append(const.dat(op2.READ))
12291204

1205+
# Finally, add the target mesh reference coordinates if they appear in the kernel
1206+
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
1207+
if target_mesh is not source_mesh:
1208+
# NOTE: TSFC will sometimes drop run-time arguments in generated
1209+
# kernels if they are deemed not-necessary.
1210+
# FIXME: Checking for argument name in the inner kernel to decide
1211+
# whether to add an extra coefficient is a stopgap until
1212+
# compile_expression_dual_evaluation
1213+
# (a) outputs a coefficient map to indicate argument ordering in
1214+
# parloops as `compile_form` does and
1215+
# (b) allows the dual evaluation related coefficients to be supplied to
1216+
# them rather than having to be added post-hoc (likely by
1217+
# replacing `to_element` with a CoFunction/CoArgument as the
1218+
# target `dual` which would contain `dual` related
1219+
# coefficient(s))
1220+
if any(arg.name == rt_var_name for arg in kernel.code[name].args):
1221+
# Add the coordinates of the target mesh quadrature points in the
1222+
# source mesh's reference cell as an extra argument for the inner
1223+
# loop. (With a vertex only mesh this is a single point for each
1224+
# vertex cell.)
1225+
target_ref_coords = target_mesh.reference_coordinates
1226+
m_ = target_ref_coords.cell_node_map()
1227+
parloop_args.append(target_ref_coords.dat(op2.READ, m_))
1228+
12301229
parloop = op2.ParLoop(*parloop_args)
12311230
parloop_compute_callable = parloop.compute
12321231
if isinstance(tensor, op2.Mat):

firedrake/ufl_expr.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,13 +336,9 @@ def adjoint(form, reordered_arguments=None, derivatives_expanded=None):
336336
# given. To avoid that, always pass reordered_arguments with
337337
# firedrake.Argument objects.
338338
if reordered_arguments is None:
339-
v, u = extract_arguments(form)
340-
reordered_arguments = (Argument(u.function_space(),
341-
number=v.number(),
342-
part=v.part()),
343-
Argument(v.function_space(),
344-
number=u.number(),
345-
part=u.part()))
339+
v, u = form.arguments()
340+
reordered_arguments = (u.reconstruct(number=v.number()),
341+
v.reconstruct(number=u.number()))
346342
return ufl.adjoint(form, reordered_arguments, derivatives_expanded=derivatives_expanded)
347343

348344

0 commit comments

Comments
 (0)