Skip to content

Commit bb02501

Browse files
committed
WIP adjoint test
1 parent e8fd26f commit bb02501

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

firedrake/assemble.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ def restructure_base_form(expr, visited=None):
832832
replace_map = {arg: left}
833833
# Decrease number for all the other arguments since the lowest numbered argument will be replaced.
834834
other_args = [a for a in right.arguments() if a is not arg]
835-
new_args = [firedrake.Argument(a.function_space(), number=a.number()-1, part=a.part()) for a in other_args]
835+
new_args = [a.reconstruct(number=a.number()-1) for a in other_args]
836836
replace_map.update(dict(zip(other_args, new_args)))
837837
# Replace arguments
838838
return ufl.replace(right, replace_map)
@@ -843,8 +843,8 @@ def restructure_base_form(expr, visited=None):
843843
u, v = B.arguments()
844844
# Let V1 and V2 be primal spaces, B: V1 -> V2 and B*: V2* -> V1*:
845845
# Adjoint(B(Argument(V1, 1), Argument(V2.dual(), 0))) = B(Argument(V1, 0), Argument(V2.dual(), 1))
846-
reordered_arguments = (firedrake.Argument(u.function_space(), number=v.number(), part=v.part()),
847-
firedrake.Argument(v.function_space(), number=u.number(), part=u.part()))
846+
reordered_arguments = (u.reconstruct(number=v.number()),
847+
v.reconstruct(number=u.number()))
848848
# Replace arguments in argument slots
849849
return ufl.replace(B, dict(zip((u, v), reordered_arguments)))
850850

tests/firedrake/adjoint/test_reduced_functional.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,21 @@ def test_assemble_recompute():
202202
h = Function(V)
203203
h.vector()[:] = 1
204204
assert taylor_test(Jhat, f, h) > 1.9
205+
206+
207+
@pytest.mark.skipcomplex
208+
def test_interpolate():
209+
mesh = UnitSquareMesh(10, 10)
210+
V = FunctionSpace(mesh, "CG", 1)
211+
Q = FunctionSpace(mesh, "DG", 0)
212+
c = Cofunction(Q.dual())
213+
c.dat.data[:] = 1
214+
215+
f = Function(V)
216+
f.dat.data[:] = 2
217+
J = assemble(Interpolate(f, c))
218+
Jhat = ReducedFunctional(J, Control(f))
219+
220+
h = Function(V)
221+
h.dat.data[:] = 1
222+
assert taylor_test(Jhat, f, h) > 1.9

0 commit comments

Comments
 (0)