Skip to content

Commit 924c9af

Browse files
committed
Support tensor kwarg
1 parent 8919f38 commit 924c9af

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

firedrake/assemble.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,11 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
563563
# Symbolic indirection for the input expression
564564
expression = firedrake.Argument(V.sub(index), number=A.number(), part=A.part())
565565
# Symbolic indirection for the output tensor
566+
if tensor is not None:
567+
assert tensor.function_space() == V.dual()
568+
tensor.zero()
566569
parent_tensor = tensor or firedrake.Cofunction(V.dual())
567-
tensor = parent_tensor.sub(index)
570+
tensor = parent_tensor.subfunctions[index]
568571

569572
# If argument numbers have been swapped => Adjoint.
570573
is_adjoint = (arg_expression and arg_expression[0].number() == 0)

tests/firedrake/regression/test_interp_dual.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,24 @@ def test_solve_interp_u(mesh):
279279
assert np.allclose(u.dat.data, u2.dat.data)
280280

281281

282-
def test_interp_subfunction(mesh):
282+
@pytest.mark.parametrize("tensor", (False, True), ids=("tensor=none", "tensor"))
283+
def test_interp_subfunction(mesh, tensor):
283284
V = FunctionSpace(mesh, "DG", 0)
284285
v = TestFunction(V)
285286
Fv = inner(1, v)*dx
286287

287-
W = V*V
288+
W = V * V
288289
w = TestFunction(W)
289-
Fw = inner(1, w[0])*dx
290+
Fw = inner(1, w[1])*dx
291+
expected = assemble(Fw)
290292

291-
I = Interpolate(w[0], Fv)
293+
IFv = Interpolate(w[1], Fv)
292294

293-
c0 = assemble(I)
294-
c1 = assemble(Fw)
295-
assert np.allclose(c0.dat.data_ro, c1.dat.data_ro)
295+
if tensor:
296+
tensor = Cofunction(W.dual())
297+
else:
298+
tensor = None
299+
300+
c = assemble(IFv, tensor=tensor)
301+
assert c.function_space() == W.dual()
302+
assert np.allclose(c.dat.data_ro, expected.dat.data_ro)

0 commit comments

Comments
 (0)