44from finat .quadrature import QuadratureRule
55from functools import partial
66
7- from ufl .compound_expressions import deviatoric_expr_2x2
8-
97
108def fs_shape (V ):
119 shape = V .ufl_function_space ().value_shape
@@ -44,34 +42,8 @@ def V(request):
4442 return FunctionSpace (mesh , element , degree )
4543
4644
47- def test_cross_mesh_oneform (V ):
48- mesh1 = UnitSquareMesh (5 , 5 )
49- mesh2 = V .mesh ()
50- x , y = SpatialCoordinate (mesh1 )
51- x1 , y1 = SpatialCoordinate (mesh2 )
52-
53- shape = V .ufl_function_space ().value_shape
54- if len (shape ) == 1 :
55- fs_type = partial (VectorFunctionSpace , dim = shape [0 ])
56- expr1 = as_vector ([x , y ])
57- expr2 = as_vector ([x1 , y1 ])
58- elif len (shape ) == 2 :
59- fs_type = partial (TensorFunctionSpace , shape = shape )
60- expr1 = as_tensor ([[x , x * y ], [x * y , y ]])
61- expr2 = as_tensor ([[x1 , x1 * y1 ], [x1 * y1 , y1 ]])
62- else :
63- raise ValueError ("Unsupported target space shape" )
64-
65- V_source = fs_type (mesh1 , "CG" , 2 )
66- f_source = Function (V_source ).interpolate (expr1 )
67-
68- f_target = assemble (interpolate (f_source , V ))
69- f_direct = Function (V ).interpolate (expr2 )
70-
71- assert np .allclose (f_target .dat .data_ro , f_direct .dat .data_ro )
72-
73-
74- def test_cross_mesh_twoform (V ):
45+ @pytest .mark .parametrize ("rank" , [1 , 2 ])
46+ def test_cross_mesh (V , rank ):
7547 mesh1 = UnitSquareMesh (5 , 5 )
7648 mesh2 = V .mesh ()
7749 x , y = SpatialCoordinate (mesh1 )
@@ -90,17 +62,42 @@ def test_cross_mesh_twoform(V):
9062 raise ValueError ("Unsupported target space shape" )
9163
9264 V_source = fs_type (mesh1 , "CG" , 2 )
93-
94- I = assemble (interpolate (TrialFunction (V_source ), V )) # V_source x V^* -> R
95-
9665 f_source = Function (V_source ).interpolate (expr1 )
9766 f_direct = Function (V ).interpolate (expr2 )
9867
99- f_interpolated = assemble (action (I , f_source ))
100- assert np .allclose (f_interpolated .dat .data_ro , f_direct .dat .data_ro )
101-
102-
103- def test_cross_mesh_oneform_adjoint (V ):
68+ Q = make_quadrature_space (V )
69+
70+ if rank == 2 :
71+ # Assemble the operator
72+ I1 = interpolate (TrialFunction (V_source ), Q ) # V_source x Q_target^* -> R
73+ I2 = interpolate (TrialFunction (Q ), V ) # Q_target x V^* -> R
74+ I_manual = assemble (action (I2 , I1 )) # V_source x V^* -> R
75+ assert I_manual .arguments () == (TestFunction (V .dual ()), TrialFunction (V_source ))
76+ # Direct assembly
77+ I_direct = assemble (interpolate (TrialFunction (V_source ), V )) # V_source
78+ assert I_direct .arguments () == (TestFunction (V .dual ()), TrialFunction (V_source ))
79+
80+ f_interpolated_manual = assemble (action (I_manual , f_source ))
81+ assert np .allclose (f_interpolated_manual .dat .data_ro , f_direct .dat .data_ro )
82+ f_interpolated_direct = assemble (action (I_direct , f_source ))
83+ assert np .allclose (f_interpolated_direct .dat .data_ro , f_direct .dat .data_ro )
84+ elif rank == 1 :
85+ # Interp V_source -> Q
86+ I1 = interpolate (f_source , Q ) # SameMesh
87+ f_quadrature = assemble (I1 )
88+ # Interp Q -> V
89+ I2 = interpolate (f_quadrature , V ) # CrossMesh
90+ f_interpolated_manual = assemble (I2 )
91+ assert f_interpolated_manual .function_space () == V
92+ assert np .allclose (f_interpolated_manual .dat .data_ro , f_direct .dat .data_ro )
93+
94+ f_interpolated_direct = assemble (interpolate (f_source , V ))
95+ assert f_interpolated_direct .function_space () == V
96+ assert np .allclose (f_interpolated_direct .dat .data_ro , f_direct .dat .data_ro )
97+
98+
99+ @pytest .mark .parametrize ("rank" , [1 , 2 ])
100+ def test_cross_mesh_adjoint (V , rank ):
104101 # Can already do Lagrange -> RT adjoint
105102 # V^* -> Q^* -> V_target^*
106103 mesh1 = UnitSquareMesh (2 , 2 )
@@ -114,59 +111,40 @@ def test_cross_mesh_oneform_adjoint(V):
114111 expr = outer (x2 , x2 )
115112 target_expr = outer (x1 , x1 )
116113 if V .ufl_element ().mapping () == "covariant contravariant Piola" :
117- expr = deviatoric_expr_2x2 (expr )
118- target_expr = deviatoric_expr_2x2 (target_expr )
114+ expr = dev (expr )
115+ target_expr = dev (target_expr )
119116 else :
120117 expr = x2
121118 target_expr = x1
122119
123120 oneform_V = inner (expr , TestFunction (V )) * dx
124-
125- Q_target = make_quadrature_space (V )
126-
127- # Interp V^* -> Q^*
128- I1_adj = interpolate (TestFunction (Q_target ), oneform_V ) # SameMesh
129- cofunc_Q = assemble (I1_adj )
130-
131- # Interp Q^* -> V_target^*
132- I2_adj = interpolate (TestFunction (V_target ), cofunc_Q ) # CrossMesh
133- cofunc_Vtarget_manual = assemble (I2_adj )
134-
135- cofunc_Vtarget = assemble (interpolate (TestFunction (V_target ), oneform_V )) # V^* -> V_target^*
136- assert np .allclose (cofunc_Vtarget_manual .dat .data_ro , cofunc_Vtarget .dat .data_ro )
137-
138121 cofunc_Vtarget_direct = assemble (inner (target_expr , TestFunction (V_target )) * dx )
139- assert np .allclose (cofunc_Vtarget .dat .data_ro , cofunc_Vtarget_direct .dat .data_ro )
140-
141-
142- def test_cross_mesh_twoform_adjoint (V ):
143- # V^* -> Q^* -> V_target^*
144- mesh1 = UnitSquareMesh (2 , 2 )
145- x1 = SpatialCoordinate (mesh1 )
146- V_target = fs_shape (V )(mesh1 , "CG" , 1 )
147- mesh2 = V .mesh ()
148- x2 = SpatialCoordinate (mesh2 )
149-
150- if len (V .value_shape ) > 1 :
151- expr = outer (x2 , x2 )
152- target_expr = outer (x1 , x1 )
153- if V .ufl_element ().mapping () == "covariant contravariant Piola" :
154- expr = deviatoric_expr_2x2 (expr )
155- target_expr = deviatoric_expr_2x2 (target_expr )
156- else :
157- expr = x2
158- target_expr = x1
159-
160- oneform_V = inner (expr , TestFunction (V )) * dx
161-
162- I = assemble (interpolate (TestFunction (V_target ), V )) # V^* x V_target -> R
163- assert I .arguments () == (TestFunction (V_target ), TrialFunction (V .dual ()))
164-
165- cofunc_V = assemble (action (I , oneform_V ))
166- cofunc_V_direct = assemble (inner (target_expr , TestFunction (V_target )) * dx )
167-
168- assert np .allclose (cofunc_V .dat .data_ro , cofunc_V_direct .dat .data_ro )
169-
170122
171- if __name__ == "__main__" :
172- pytest .main ([__file__ + "::test_cross_mesh_oneform_adjoint[RT_2]" ])
123+ Q = make_quadrature_space (V )
124+
125+ if rank == 2 :
126+ # Assemble the operator
127+ I1 = interpolate (TestFunction (Q ), V ) # V^* x Q -> R
128+ I2 = interpolate (TestFunction (V_target ), Q ) # Q^* x V_target -> R
129+ I_manual = assemble (action (I2 , I1 )) # V^* x V_target -> R
130+ assert I_manual .arguments () == (TestFunction (V_target ), TrialFunction (V .dual ()))
131+ # Direct assembly
132+ I_direct = assemble (interpolate (TestFunction (V_target ), V )) # V^* x V_target -> R
133+ assert I_direct .arguments () == (TestFunction (V_target ), TrialFunction (V .dual ()))
134+
135+ cofunc_Vtarget_manual = assemble (action (I_manual , oneform_V ))
136+ assert np .allclose (cofunc_Vtarget_manual .dat .data_ro , cofunc_Vtarget_direct .dat .data_ro )
137+ cofunc_Vtarget = assemble (action (I_direct , oneform_V ))
138+ assert np .allclose (cofunc_Vtarget .dat .data_ro , cofunc_Vtarget_direct .dat .data_ro )
139+ elif rank == 1 :
140+ # Interp V^* -> Q^*
141+ I1_adj = interpolate (TestFunction (Q ), oneform_V ) # SameMesh
142+ cofunc_Q = assemble (I1_adj )
143+
144+ # Interp Q^* -> V_target^*
145+ I2_adj = interpolate (TestFunction (V_target ), cofunc_Q ) # CrossMesh
146+ cofunc_Vtarget_manual = assemble (I2_adj )
147+ assert np .allclose (cofunc_Vtarget_manual .dat .data_ro , cofunc_Vtarget_direct .dat .data_ro )
148+
149+ cofunc_Vtarget = assemble (interpolate (TestFunction (V_target ), oneform_V )) # V^* -> V_target^*
150+ assert np .allclose (cofunc_Vtarget .dat .data_ro , cofunc_Vtarget_direct .dat .data_ro )
0 commit comments