Skip to content

Commit 0d5a629

Browse files
committed
Interpolate: Fix parloop args ordering
1 parent 6e10ebf commit 0d5a629

File tree

5 files changed

+37
-38
lines changed

5 files changed

+37
-38
lines changed

firedrake/assemble.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
494494
if all(isinstance(op, numbers.Complex) for op in args):
495495
result = sum(weight * arg for weight, arg in zip(expr.weights(), args))
496496
return tensor.assign(result) if tensor else result
497-
elif all(isinstance(op, firedrake.Cofunction) for op in args):
497+
elif (all(isinstance(op, firedrake.Cofunction) for op in args)
498+
or all(isinstance(op, firedrake.Function) for op in args)):
498499
V, = set(a.function_space() for a in args)
499-
result = tensor if tensor else firedrake.Cofunction(V)
500+
result = tensor if tensor else firedrake.Function(V)
500501
result.dat.maxpy(expr.weights(), [a.dat for a in args])
501502
return result
502503
elif all(isinstance(op, ufl.Matrix) for op in args):
@@ -541,7 +542,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
541542
if assembled_expression:
542543
# Occur in situations such as Interpolate composition
543544
expression = assembled_expression[0]
544-
expr = expr._ufl_expr_reconstruct_(expression, v)
545+
if (v, expression) != expr.argument_slots():
546+
expr = expr._ufl_expr_reconstruct_(expression, v=v)
545547

546548
# Different assembly procedures:
547549
# 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix)
@@ -553,8 +555,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
553555
arg_expression = ufl.algorithms.extract_arguments(expression)
554556

555557
# Handle interpolation onto subfunctions
556-
if (arg_expression and len(arg_expression[0].function_space()) > 1
557-
and not isinstance(expression, firedrake.Argument)):
558+
if arg_expression and len(arg_expression[0].function_space()) > 1:
558559
assert rank == 1
559560
V = arg_expression[0].function_space()
560561
if tensor is not None:

firedrake/interpolation.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,9 @@ def __init__(self, expr, v,
8585
the :meth:`interpolate` method or (b) set to zero.
8686
Ignored if interpolating within the same mesh or onto a :func:`.VertexOnlyMesh`.
8787
"""
88-
8988
# Check function space
9089
if isinstance(v, functionspaceimpl.WithGeometry):
9190
v = Argument(v.dual(), 0)
92-
93-
# Get the primal space (V** = V)
94-
vv = v if not isinstance(v, ufl.Form) else v.arguments()[0]
95-
self._function_space = vv.function_space().dual()
9691
super().__init__(expr, v)
9792

9893
# -- Interpolate data (e.g. `subset` or `access`) -- #
@@ -101,8 +96,7 @@ def __init__(self, expr, v,
10196
"allow_missing_dofs": allow_missing_dofs,
10297
"default_missing_val": default_missing_val}
10398

104-
def function_space(self):
105-
return self._function_space
99+
function_space = ufl.Interpolate.ufl_function_space
106100

107101
def _ufl_expr_reconstruct_(self, expr, v=None, **interp_data):
108102
interp_data = interp_data or self.interp_data.copy()
@@ -1123,12 +1117,14 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11231117
name = kernel.name
11241118
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
11251119
flop_count=kernel.flop_count, events=(kernel.event,))
1120+
11261121
parloop_args = [kernel, cell_set]
11271122

11281123
coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
11291124
if needs_external_coords:
11301125
coefficients = [source_mesh.coordinates] + coefficients
11311126

1127+
needs_target_ref_coords = False
11321128
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
11331129
if target_mesh is not source_mesh:
11341130
# NOTE: TSFC will sometimes drop run-time arguments in generated
@@ -1143,14 +1139,10 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11431139
# replacing `to_element` with a CoFunction/CoArgument as the
11441140
# target `dual` which would contain `dual` related
11451141
# coefficient(s))
1146-
if rt_var_name in [arg.name for arg in kernel.code[name].args]:
1147-
# Add the coordinates of the target mesh quadrature points in the
1148-
# source mesh's reference cell as an extra argument for the inner
1149-
# loop. (With a vertex only mesh this is a single point for each
1150-
# vertex cell.)
1151-
coefficients.append(target_mesh.reference_coordinates)
1152-
1153-
if tensor in set((c.dat for c in coefficients)):
1142+
if any(arg.name == rt_var_name for arg in kernel.code[name].args):
1143+
needs_target_ref_coords = True
1144+
1145+
if tensor in set(c.dat for c in coefficients):
11541146
output = tensor
11551147
tensor = op2.Dat(tensor.dataset)
11561148
if access is not op2.WRITE:
@@ -1196,6 +1188,7 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
11961188
if needs_cell_sizes:
11971189
cs = target_mesh.cell_sizes
11981190
parloop_args.append(cs.dat(op2.READ, cs.cell_node_map()))
1191+
11991192
for coefficient in coefficients:
12001193
if isinstance(target_mesh.topology, firedrake.mesh.VertexOnlyMeshTopology):
12011194
coeff_mesh = extract_unique_domain(coefficient)
@@ -1227,6 +1220,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
12271220
for const in extract_firedrake_constants(expr):
12281221
parloop_args.append(const.dat(op2.READ))
12291222

1223+
if needs_target_ref_coords:
1224+
# Add the coordinates of the target mesh quadrature points in the
1225+
# source mesh's reference cell as an extra argument for the inner
1226+
# loop. (With a vertex only mesh this is a single point for each
1227+
# vertex cell.)
1228+
target_ref_coords = target_mesh.reference_coordinates
1229+
m_ = target_ref_coords.cell_node_map()
1230+
parloop_args.append(target_ref_coords.dat(op2.READ, m_))
1231+
12301232
parloop = op2.ParLoop(*parloop_args)
12311233
parloop_compute_callable = parloop.compute
12321234
if isinstance(tensor, op2.Mat):

firedrake/ufl_expr.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,13 +336,9 @@ def adjoint(form, reordered_arguments=None, derivatives_expanded=None):
336336
# given. To avoid that, always pass reordered_arguments with
337337
# firedrake.Argument objects.
338338
if reordered_arguments is None:
339-
v, u = extract_arguments(form)
340-
reordered_arguments = (Argument(u.function_space(),
341-
number=v.number(),
342-
part=v.part()),
343-
Argument(v.function_space(),
344-
number=u.number(),
345-
part=u.part()))
339+
v, u = form.arguments()
340+
reordered_arguments = (u.reconstruct(number=v.number()),
341+
v.reconstruct(number=u.number()))
346342
return ufl.adjoint(form, reordered_arguments, derivatives_expanded=derivatives_expanded)
347343

348344

firedrake/variational_solver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
DEFAULT_SNES_PARAMETERS
1010
)
1111
from firedrake.function import Function
12+
from firedrake.interpolation import Interpolate
1213
from firedrake.matrix import MatrixBase
13-
from firedrake.ufl_expr import TrialFunction, TestFunction, action
14+
from firedrake.ufl_expr import TrialFunction, TestFunction
1415
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
1516
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
16-
from firedrake.__future__ import interpolate
1717
from ufl import replace, Form
1818

1919
__all__ = ["LinearVariationalProblem",
@@ -100,7 +100,7 @@ def __init__(self, F, u, bcs=None, J=None,
100100
F_arg, = F.arguments()
101101
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
102102
else:
103-
self.F = action(replace(F, {self.u: self.u_restrict}), interpolate(v_res, V))
103+
self.F = Interpolate(v_res, replace(F, {self.u: self.u_restrict}))
104104
v_arg, u_arg = self.J.arguments()
105105
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
106106
if self.Jp:

tests/firedrake/regression/test_interp_dual.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_solve_interp_u(mesh):
280280

281281

282282
@pytest.mark.parametrize("family0,degree0,family1,degree1",
283-
[("DG", 1, "CG", 1),
283+
[("DG", 1, "CG", 2),
284284
("DG", 0, "RT", 1)])
285285
def test_interp_subfunction(mesh, family0, degree0, family1, degree1):
286286
V = FunctionSpace(mesh, "DG", 0)
@@ -291,20 +291,20 @@ def test_interp_subfunction(mesh, family0, degree0, family1, degree1):
291291
W1 = FunctionSpace(mesh, family1, degree1)
292292
W = W0 * W1
293293
w = TestFunction(W)
294-
c = Cofunction(W.dual())
295294

296295
expr = sum(w[i] for i in np.ndindex(w.ufl_shape))
297296

298-
Fw = inner(1, expr)*dx
297+
Fw = inner(1, expr)*dx(degree=0)
299298
expected = assemble(Fw)
300299

301300
IFv = Interpolate(expr, Fv)
302301

302+
c = Cofunction(W.dual())
303+
c.assign(99)
303304
for tensor in (None, c):
304-
if tensor:
305-
tensor.assign(99)
306-
307305
result = assemble(IFv, tensor=tensor)
308306
assert result.function_space() == W.dual()
309-
with result.dat.vec_ro as x, expected.dat.vec_ro as y:
310-
assert np.allclose(x[:], y[:])
307+
if tensor:
308+
assert result is tensor
309+
for x, y, in zip(result.subfunctions, expected.subfunctions):
310+
assert np.allclose(x.dat.data_ro, y.dat.data_ro)

0 commit comments

Comments
 (0)