Skip to content

Commit 416fcde

Browse files
committed
fix test
fixes
1 parent fe1909b commit 416fcde

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

tests/firedrake/regression/test_cross_mesh_non_lagrange.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from finat.quadrature import QuadratureRule
55
from functools import partial
66

7+
from ufl.compound_expressions import deviatoric_expr_2x2
8+
79

810
def fs_shape(V):
911
shape = V.ufl_function_space().value_shape
@@ -30,14 +32,14 @@ def make_quadrature_space(V):
3032
return fs_shape(V)(V.mesh(), element)
3133

3234

33-
@pytest.fixture(params=[("RT", 1), ("RT", 2), ("RT", 3), ("BDM", 1), ("BDM", 2), ("BDM", 3),
34-
("BDFM", 2), ("HHJ", 2),("N1curl", 1), ("N1curl", 2), ("N1curl", 3),
35-
("N2curl", 1), ("N2curl", 2), ("N2curl", 3), ("GLS", 1), ("GLS", 2),
36-
("GLS", 3), ("GLS2", 1), ("GLS2", 2), ("GLS2", 3)],
35+
@pytest.fixture(params=[("RT", 2), ("RT", 3), ("RT", 4), ("BDM", 1), ("BDM", 2), ("BDM", 3),
36+
("BDFM", 2), ("HHJ", 2),("N1curl", 2), ("N1curl", 3), ("N1curl", 4),
37+
("N2curl", 1), ("N2curl", 2), ("N2curl", 3), ("GLS", 2), ("GLS", 3),
38+
("GLS", 4), ("GLS2", 1), ("GLS2", 2), ("GLS2", 3)],
3739
ids=lambda x: f"{x[0]}_{x[1]}")
3840
def V(request):
3941
element, degree = request.param
40-
mesh = UnitSquareMesh(3, 3)
42+
mesh = UnitSquareMesh(8, 8)
4143
return FunctionSpace(mesh, element, degree)
4244

4345
# V_source -> Q -> V_target
@@ -83,14 +85,24 @@ def test_cross_mesh_oneform(V):
8385
def test_cross_mesh_oneform_adjoint(V):
8486
# Can already do Lagrange -> RT adjoint
8587
# V^* -> Q^* -> V_target^*
86-
mesh1 = UnitSquareMesh(7, 7)
88+
mesh1 = UnitSquareMesh(2, 2)
8789
x1 = SpatialCoordinate(mesh1)
88-
V_target = fs_shape(V)(mesh1, "CG", 2)
90+
V_target = fs_shape(V)(mesh1, "CG", 1)
8991

9092
mesh2 = V.mesh()
9193
x2 = SpatialCoordinate(mesh2)
9294

93-
oneform_V = inner(x2, TestFunction(V)) * dx
95+
if len(V.value_shape) > 1:
96+
expr = outer(x2, x2)
97+
target_expr = outer(x1, x1)
98+
if V.ufl_element().mapping() == "covariant contravariant Piola":
99+
expr = deviatoric_expr_2x2(expr)
100+
target_expr = deviatoric_expr_2x2(target_expr)
101+
else:
102+
expr = x2
103+
target_expr = x1
104+
105+
oneform_V = inner(expr, TestFunction(V)) * dx
94106

95107
Q_target = make_quadrature_space(V)
96108

@@ -104,9 +116,8 @@ def test_cross_mesh_oneform_adjoint(V):
104116

105117
# cofunc_V = assemble(interpolate(TestFunction(V_target), oneform_target)) # V^* -> V_target^*
106118

107-
cofunc_V_direct = assemble(inner(x1, TestFunction(V_target)) * dx)
108-
119+
cofunc_V_direct = assemble(inner(target_expr, TestFunction(V_target)) * dx)
109120
assert np.allclose(cofunc_V.dat.data_ro, cofunc_V_direct.dat.data_ro)
110121

111122
if __name__ == "__main__":
112-
pytest.main([__file__ + "::test_cross_mesh_oneform_adjoint[RT_1]"])
123+
pytest.main([__file__ + "::test_cross_mesh_oneform_adjoint[GLS_3]"])

tsfc/driver.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,26 +228,17 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
228228

229229
if isinstance(to_element, (PhysicallyMappedElement, DirectlyDefinedElement)):
230230
raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry")
231-
232-
if domain is None:
233-
domain = extract_unique_domain(expression)
234-
assert domain is not None
235231

236232
orig_coefficients = extract_coefficients(expression)
237-
v, operand = expression.argument_slots()
233+
if isinstance(expression, ufl.Interpolate):
234+
v, operand = expression.argument_slots()
235+
else:
236+
operand = expression
237+
v = ufl.FunctionSpace(extract_unique_domain(operand), ufl_element)
238238

239239
# Map into reference space
240240
operand = apply_mapping(operand, ufl_element, domain)
241241

242-
if ufl_element.mapping() != "identity":
243-
# Need to map dual argument for adjoint interpolation
244-
ref_element = finat.ufl.WithMapping(ufl_element, "identity")
245-
V = ufl.FunctionSpace(domain, ref_element)
246-
if isinstance(v, ufl.Coargument):
247-
v = ufl.Coargument(V.dual(), v.number())
248-
else:
249-
v = ufl.Cofunction(V.dual())
250-
251242
# Apply UFL preprocessing
252243
operand = ufl_utils.preprocess_expression(operand, complex_mode=complex_mode)
253244

@@ -266,6 +257,9 @@ def compile_expression_dual_evaluation(expression, to_element, ufl_element, *,
266257
assert len(argument_multiindices) == len(arguments)
267258

268259
# Replace coordinates (if any) unless otherwise specified by kwarg
260+
if domain is None:
261+
domain = extract_unique_domain(expression)
262+
assert domain is not None
269263
builder._domain_integral_type_map = {domain: "cell"}
270264
builder._entity_ids = {domain: (0,)}
271265

0 commit comments

Comments
 (0)