Skip to content

Commit cb92d81

Browse files
committed
working
remove
1 parent fc17045 commit cb92d81

File tree

3 files changed

+78
-100
lines changed

3 files changed

+78
-100
lines changed

firedrake/interpolation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,12 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
579579
Q_dest = None
580580
target_space = Q_dest or self.target_space
581581

582-
# self.ufl_interpolate.function_space() is None in the 0-form case
583-
f = tensor or Function(self.ufl_interpolate.function_space() or self.target_space)
582+
if into_quadrature_space and not self.ufl_interpolate.is_adjoint:
583+
f = Function(target_space)
584+
else:
585+
# self.ufl_interpolate.function_space() is None in the 0-form case
586+
f = Function(self.ufl_interpolate.function_space() or target_space)
587+
# f = tensor or Function(target_space.dual() if self.ufl_interpolate.is_adjoint else target_space)
584588

585589
point_eval, point_eval_input_ordering = self._get_symbolic_expressions(target_space)
586590
P0DG_vom_input_ordering = point_eval_input_ordering.argument_slots()[0].function_space().dual()

tests/firedrake/regression/test_cross_mesh_non_lagrange.py

Lines changed: 65 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from finat.quadrature import QuadratureRule
55
from functools import partial
66

7-
from ufl.compound_expressions import deviatoric_expr_2x2
8-
97

108
def fs_shape(V):
119
shape = V.ufl_function_space().value_shape
@@ -44,34 +42,8 @@ def V(request):
4442
return FunctionSpace(mesh, element, degree)
4543

4644

47-
def test_cross_mesh_oneform(V):
48-
mesh1 = UnitSquareMesh(5, 5)
49-
mesh2 = V.mesh()
50-
x, y = SpatialCoordinate(mesh1)
51-
x1, y1 = SpatialCoordinate(mesh2)
52-
53-
shape = V.ufl_function_space().value_shape
54-
if len(shape) == 1:
55-
fs_type = partial(VectorFunctionSpace, dim=shape[0])
56-
expr1 = as_vector([x, y])
57-
expr2 = as_vector([x1, y1])
58-
elif len(shape) == 2:
59-
fs_type = partial(TensorFunctionSpace, shape=shape)
60-
expr1 = as_tensor([[x, x*y], [x*y, y]])
61-
expr2 = as_tensor([[x1, x1*y1], [x1*y1, y1]])
62-
else:
63-
raise ValueError("Unsupported target space shape")
64-
65-
V_source = fs_type(mesh1, "CG", 2)
66-
f_source = Function(V_source).interpolate(expr1)
67-
68-
f_target = assemble(interpolate(f_source, V))
69-
f_direct = Function(V).interpolate(expr2)
70-
71-
assert np.allclose(f_target.dat.data_ro, f_direct.dat.data_ro)
72-
73-
74-
def test_cross_mesh_twoform(V):
45+
@pytest.mark.parametrize("rank", [1, 2])
46+
def test_cross_mesh(V, rank):
7547
mesh1 = UnitSquareMesh(5, 5)
7648
mesh2 = V.mesh()
7749
x, y = SpatialCoordinate(mesh1)
@@ -90,17 +62,42 @@ def test_cross_mesh_twoform(V):
9062
raise ValueError("Unsupported target space shape")
9163

9264
V_source = fs_type(mesh1, "CG", 2)
93-
94-
I = assemble(interpolate(TrialFunction(V_source), V)) # V_source x V^* -> R
95-
9665
f_source = Function(V_source).interpolate(expr1)
9766
f_direct = Function(V).interpolate(expr2)
9867

99-
f_interpolated = assemble(action(I, f_source))
100-
assert np.allclose(f_interpolated.dat.data_ro, f_direct.dat.data_ro)
101-
102-
103-
def test_cross_mesh_oneform_adjoint(V):
68+
Q = make_quadrature_space(V)
69+
70+
if rank == 2:
71+
# Assemble the operator
72+
I1 = interpolate(TrialFunction(V_source), Q) # V_source x Q_target^* -> R
73+
I2 = interpolate(TrialFunction(Q), V) # Q_target x V^* -> R
74+
I_manual = assemble(action(I2, I1)) # V_source x V^* -> R
75+
assert I_manual.arguments() == (TestFunction(V.dual()), TrialFunction(V_source))
76+
# Direct assembly
77+
I_direct = assemble(interpolate(TrialFunction(V_source), V)) # V_source
78+
assert I_direct.arguments() == (TestFunction(V.dual()), TrialFunction(V_source))
79+
80+
f_interpolated_manual = assemble(action(I_manual, f_source))
81+
assert np.allclose(f_interpolated_manual.dat.data_ro, f_direct.dat.data_ro)
82+
f_interpolated_direct = assemble(action(I_direct, f_source))
83+
assert np.allclose(f_interpolated_direct.dat.data_ro, f_direct.dat.data_ro)
84+
elif rank == 1:
85+
# Interp V_source -> Q
86+
I1 = interpolate(f_source, Q) # SameMesh
87+
f_quadrature = assemble(I1)
88+
# Interp Q -> V
89+
I2 = interpolate(f_quadrature, V) # CrossMesh
90+
f_interpolated_manual = assemble(I2)
91+
assert f_interpolated_manual.function_space() == V
92+
assert np.allclose(f_interpolated_manual.dat.data_ro, f_direct.dat.data_ro)
93+
94+
f_interpolated_direct = assemble(interpolate(f_source, V))
95+
assert f_interpolated_direct.function_space() == V
96+
assert np.allclose(f_interpolated_direct.dat.data_ro, f_direct.dat.data_ro)
97+
98+
99+
@pytest.mark.parametrize("rank", [1, 2])
100+
def test_cross_mesh_adjoint(V, rank):
104101
# Can already do Lagrange -> RT adjoint
105102
# V^* -> Q^* -> V_target^*
106103
mesh1 = UnitSquareMesh(2, 2)
@@ -114,59 +111,40 @@ def test_cross_mesh_oneform_adjoint(V):
114111
expr = outer(x2, x2)
115112
target_expr = outer(x1, x1)
116113
if V.ufl_element().mapping() == "covariant contravariant Piola":
117-
expr = deviatoric_expr_2x2(expr)
118-
target_expr = deviatoric_expr_2x2(target_expr)
114+
expr = dev(expr)
115+
target_expr = dev(target_expr)
119116
else:
120117
expr = x2
121118
target_expr = x1
122119

123120
oneform_V = inner(expr, TestFunction(V)) * dx
124-
125-
Q_target = make_quadrature_space(V)
126-
127-
# Interp V^* -> Q^*
128-
I1_adj = interpolate(TestFunction(Q_target), oneform_V) # SameMesh
129-
cofunc_Q = assemble(I1_adj)
130-
131-
# Interp Q^* -> V_target^*
132-
I2_adj = interpolate(TestFunction(V_target), cofunc_Q) # CrossMesh
133-
cofunc_Vtarget_manual = assemble(I2_adj)
134-
135-
cofunc_Vtarget = assemble(interpolate(TestFunction(V_target), oneform_V)) # V^* -> V_target^*
136-
assert np.allclose(cofunc_Vtarget_manual.dat.data_ro, cofunc_Vtarget.dat.data_ro)
137-
138121
cofunc_Vtarget_direct = assemble(inner(target_expr, TestFunction(V_target)) * dx)
139-
assert np.allclose(cofunc_Vtarget.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro)
140-
141-
142-
def test_cross_mesh_twoform_adjoint(V):
143-
# V^* -> Q^* -> V_target^*
144-
mesh1 = UnitSquareMesh(2, 2)
145-
x1 = SpatialCoordinate(mesh1)
146-
V_target = fs_shape(V)(mesh1, "CG", 1)
147-
mesh2 = V.mesh()
148-
x2 = SpatialCoordinate(mesh2)
149-
150-
if len(V.value_shape) > 1:
151-
expr = outer(x2, x2)
152-
target_expr = outer(x1, x1)
153-
if V.ufl_element().mapping() == "covariant contravariant Piola":
154-
expr = deviatoric_expr_2x2(expr)
155-
target_expr = deviatoric_expr_2x2(target_expr)
156-
else:
157-
expr = x2
158-
target_expr = x1
159-
160-
oneform_V = inner(expr, TestFunction(V)) * dx
161-
162-
I = assemble(interpolate(TestFunction(V_target), V)) # V^* x V_target -> R
163-
assert I.arguments() == (TestFunction(V_target), TrialFunction(V.dual()))
164-
165-
cofunc_V = assemble(action(I, oneform_V))
166-
cofunc_V_direct = assemble(inner(target_expr, TestFunction(V_target)) * dx)
167-
168-
assert np.allclose(cofunc_V.dat.data_ro, cofunc_V_direct.dat.data_ro)
169-
170122

171-
if __name__ == "__main__":
172-
pytest.main([__file__ + "::test_cross_mesh_oneform_adjoint[RT_2]"])
123+
Q = make_quadrature_space(V)
124+
125+
if rank == 2:
126+
# Assemble the operator
127+
I1 = interpolate(TestFunction(Q), V) # V^* x Q -> R
128+
I2 = interpolate(TestFunction(V_target), Q) # Q^* x V_target -> R
129+
I_manual = assemble(action(I2, I1)) # V^* x V_target -> R
130+
assert I_manual.arguments() == (TestFunction(V_target), TrialFunction(V.dual()))
131+
# Direct assembly
132+
I_direct = assemble(interpolate(TestFunction(V_target), V)) # V^* x V_target -> R
133+
assert I_direct.arguments() == (TestFunction(V_target), TrialFunction(V.dual()))
134+
135+
cofunc_Vtarget_manual = assemble(action(I_manual, oneform_V))
136+
assert np.allclose(cofunc_Vtarget_manual.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro)
137+
cofunc_Vtarget = assemble(action(I_direct, oneform_V))
138+
assert np.allclose(cofunc_Vtarget.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro)
139+
elif rank == 1:
140+
# Interp V^* -> Q^*
141+
I1_adj = interpolate(TestFunction(Q), oneform_V) # SameMesh
142+
cofunc_Q = assemble(I1_adj)
143+
144+
# Interp Q^* -> V_target^*
145+
I2_adj = interpolate(TestFunction(V_target), cofunc_Q) # CrossMesh
146+
cofunc_Vtarget_manual = assemble(I2_adj)
147+
assert np.allclose(cofunc_Vtarget_manual.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro)
148+
149+
cofunc_Vtarget = assemble(interpolate(TestFunction(V_target), oneform_V)) # V^* -> V_target^*
150+
assert np.allclose(cofunc_Vtarget.dat.data_ro, cofunc_Vtarget_direct.dat.data_ro)

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def allgather(comm, coords):
1515
return coords
1616

1717

18-
def unitsquaresetup():
18+
def unitsquaresetup(dest_quad=True):
1919
m_src = UnitSquareMesh(2, 3)
20-
m_dest = UnitSquareMesh(3, 5, quadrilateral=True)
20+
m_dest = UnitSquareMesh(3, 5, quadrilateral=dest_quad)
2121
coords = np.array(
2222
[[0.56, 0.6], [0.1, 0.9], [0.9, 0.1], [0.9, 0.9], [0.726, 0.6584]]
2323
) # fairly arbitrary
@@ -38,7 +38,7 @@ def make_high_order(m_low_order, degree):
3838

3939
@pytest.fixture(
4040
params=[
41-
"unitsquare",
41+
"unitsquare_RT_N1curl_destination",
4242
"circlemanifold",
4343
"circlemanifold_to_high_order",
4444
"unitsquare_from_high_order",
@@ -47,7 +47,7 @@ def make_high_order(m_low_order, degree):
4747
"unitsquare_vfs",
4848
"unitsquare_tfs",
4949
"unitsquare_N1curl_source",
50-
"unitsquare_SminusDiv_destination",
50+
"unitsquare_RT_N1curl_destination",
5151
"unitsquare_Regge_source",
5252
# This test fails in complex mode
5353
pytest.param("spheresphere", marks=pytest.mark.skipcomplex),
@@ -179,14 +179,14 @@ def parameters(request):
179179
V_src = FunctionSpace(m_src, "N1curl", 2) # Not point evaluation nodes
180180
V_dest = VectorFunctionSpace(m_dest, "CG", 4)
181181
V_dest_2 = VectorFunctionSpace(m_dest, "DQ", 2)
182-
elif request.param == "unitsquare_SminusDiv_destination":
183-
m_src, m_dest, coords = unitsquaresetup()
182+
elif request.param == "unitsquare_RT_N1curl_destination":
183+
m_src, m_dest, coords = unitsquaresetup(dest_quad=False)
184184
expr_src = 2 * SpatialCoordinate(m_src)
185185
expr_dest = 2 * SpatialCoordinate(m_dest)
186186
expected = 2 * coords
187187
V_src = VectorFunctionSpace(m_src, "CG", 2)
188188
V_dest = FunctionSpace(m_dest, "RT", 2) # Not point evaluation nodes
189-
V_dest_2 = FunctionSpace(m_dest, "N1Curl", 2) # Not point evaluation nodes
189+
V_dest_2 = FunctionSpace(m_dest, "N1curl", 2) # Not point evaluation nodes
190190
elif request.param == "unitsquare_Regge_source":
191191
m_src, m_dest, coords = unitsquaresetup()
192192
expr_src = outer(SpatialCoordinate(m_src), SpatialCoordinate(m_src))
@@ -743,7 +743,3 @@ def test_interpolate_cross_mesh_interval(periodic):
743743
f_dest = Function(V_dest).interpolate(f_src)
744744
x_dest, = SpatialCoordinate(m_dest)
745745
assert abs(assemble((f_dest - (-(x_dest - .5) ** 2)) ** 2 * dx)) < 1.e-16
746-
747-
748-
if __name__ == "__main__":
749-
pytest.main([__file__ + "::test_interpolate_matrix_cross_mesh_adjoint"])

0 commit comments

Comments
 (0)