44from finat .quadrature import QuadratureRule
55from functools import partial
66
7+ from ufl .compound_expressions import deviatoric_expr_2x2
8+
79
810def fs_shape (V ):
911 shape = V .ufl_function_space ().value_shape
@@ -30,14 +32,14 @@ def make_quadrature_space(V):
3032 return fs_shape (V )(V .mesh (), element )
3133
3234
33- @pytest .fixture (params = [("RT" , 1 ), ("RT" , 2 ), ("RT" , 3 ), ("BDM" , 1 ), ("BDM" , 2 ), ("BDM" , 3 ),
34- ("BDFM" , 2 ), ("HHJ" , 2 ),("N1curl" , 1 ), ("N1curl" , 2 ), ("N1curl" , 3 ),
35- ("N2curl" , 1 ), ("N2curl" , 2 ), ("N2curl" , 3 ), ("GLS" , 1 ), ("GLS" , 2 ),
36- ("GLS" , 3 ), ("GLS2" , 1 ), ("GLS2" , 2 ), ("GLS2" , 3 )],
35+ @pytest .fixture (params = [("RT" , 2 ), ("RT" , 3 ), ("RT" , 4 ), ("BDM" , 1 ), ("BDM" , 2 ), ("BDM" , 3 ),
36+ ("BDFM" , 2 ), ("HHJ" , 2 ),("N1curl" , 2 ), ("N1curl" , 3 ), ("N1curl" , 4 ),
37+ ("N2curl" , 1 ), ("N2curl" , 2 ), ("N2curl" , 3 ), ("GLS" , 2 ), ("GLS" , 3 ),
38+ ("GLS" , 4 ), ("GLS2" , 1 ), ("GLS2" , 2 ), ("GLS2" , 3 )],
3739 ids = lambda x : f"{ x [0 ]} _{ x [1 ]} " )
3840def V (request ):
3941 element , degree = request .param
40- mesh = UnitSquareMesh (3 , 3 )
42+ mesh = UnitSquareMesh (8 , 8 )
4143 return FunctionSpace (mesh , element , degree )
4244
4345# V_source -> Q -> V_target
@@ -83,14 +85,24 @@ def test_cross_mesh_oneform(V):
8385def test_cross_mesh_oneform_adjoint (V ):
8486 # Can already do Lagrange -> RT adjoint
8587 # V^* -> Q^* -> V_target^*
86- mesh1 = UnitSquareMesh (7 , 7 )
88+ mesh1 = UnitSquareMesh (2 , 2 )
8789 x1 = SpatialCoordinate (mesh1 )
88- V_target = fs_shape (V )(mesh1 , "CG" , 2 )
90+ V_target = fs_shape (V )(mesh1 , "CG" , 1 )
8991
9092 mesh2 = V .mesh ()
9193 x2 = SpatialCoordinate (mesh2 )
9294
93- oneform_V = inner (x2 , TestFunction (V )) * dx
95+ if len (V .value_shape ) > 1 :
96+ expr = outer (x2 , x2 )
97+ target_expr = outer (x1 , x1 )
98+ if V .ufl_element ().mapping () == "covariant contravariant Piola" :
99+ expr = deviatoric_expr_2x2 (expr )
100+ target_expr = deviatoric_expr_2x2 (target_expr )
101+ else :
102+ expr = x2
103+ target_expr = x1
104+
105+ oneform_V = inner (expr , TestFunction (V )) * dx
94106
95107 Q_target = make_quadrature_space (V )
96108
@@ -104,9 +116,8 @@ def test_cross_mesh_oneform_adjoint(V):
104116
105117 # cofunc_V = assemble(interpolate(TestFunction(V_target), oneform_target)) # V^* -> V_target^*
106118
107- cofunc_V_direct = assemble (inner (x1 , TestFunction (V_target )) * dx )
108-
119+ cofunc_V_direct = assemble (inner (target_expr , TestFunction (V_target )) * dx )
109120 assert np .allclose (cofunc_V .dat .data_ro , cofunc_V_direct .dat .data_ro )
110121
111122if __name__ == "__main__" :
112- pytest .main ([__file__ + "::test_cross_mesh_oneform_adjoint[RT_1 ]" ])
123+ pytest .main ([__file__ + "::test_cross_mesh_oneform_adjoint[GLS_3 ]" ])
0 commit comments