Skip to content

Commit 2e34465

Browse files
committed
Handle tensor kwarg
1 parent 8919f38 commit 2e34465

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

firedrake/assemble.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
549549
# 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint
550550
# This can be generalized to the case where the first slot is an arbitray expression.
551551
rank = len(expr.arguments())
552-
arg_expression = ufl.algorithms.extract_arguments(expression)
553552

554-
# Handle interpolation of subfunctions
553+
# Handle interpolation onto subfunctions
555554
parent_tensor = None
556555
if isinstance(expression, ufl.classes.Indexed):
557556
assert rank == 1
@@ -563,10 +562,14 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
563562
# Symbolic indirection for the input expression
564563
expression = firedrake.Argument(V.sub(index), number=A.number(), part=A.part())
565564
# Symbolic indirection for the output tensor
565+
if tensor is not None:
566+
assert tensor.function_space() == V.dual()
567+
tensor.zero()
566568
parent_tensor = tensor or firedrake.Cofunction(V.dual())
567-
tensor = parent_tensor.sub(index)
569+
tensor = parent_tensor.subfunctions[index]
568570

569571
# If argument numbers have been swapped => Adjoint.
572+
arg_expression = ufl.algorithms.extract_arguments(expression)
570573
is_adjoint = (arg_expression and arg_expression[0].number() == 0)
571574
# Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument.
572575
if not is_adjoint and rank != 1:
@@ -575,7 +578,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
575578
# Get the interpolator
576579
interp_data = expr.interp_data
577580
default_missing_val = interp_data.pop('default_missing_val', None)
578-
579581
interpolator = firedrake.Interpolator(expression, expr.function_space(), **interp_data)
580582
# Assembly
581583
if rank == 1:

tests/firedrake/regression/test_interp_dual.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,24 @@ def test_solve_interp_u(mesh):
279279
assert np.allclose(u.dat.data, u2.dat.data)
280280

281281

282-
def test_interp_subfunction(mesh):
282+
@pytest.mark.parametrize("tensor", (False, True), ids=("tensor=none", "tensor"))
283+
def test_interp_subfunction(mesh, tensor):
283284
V = FunctionSpace(mesh, "DG", 0)
284285
v = TestFunction(V)
285286
Fv = inner(1, v)*dx
286287

287-
W = V*V
288+
W = V * V
288289
w = TestFunction(W)
289-
Fw = inner(1, w[0])*dx
290+
Fw = inner(1, w[1])*dx
291+
expected = assemble(Fw)
290292

291-
I = Interpolate(w[0], Fv)
293+
IFv = Interpolate(w[1], Fv)
292294

293-
c0 = assemble(I)
294-
c1 = assemble(Fw)
295-
assert np.allclose(c0.dat.data_ro, c1.dat.data_ro)
295+
if tensor:
296+
tensor = Cofunction(W.dual())
297+
else:
298+
tensor = None
299+
300+
c = assemble(IFv, tensor=tensor)
301+
assert c.function_space() == W.dual()
302+
assert np.allclose(c.dat.data_ro, expected.dat.data_ro)

0 commit comments

Comments
 (0)