@@ -181,21 +181,21 @@ def coarsen(self, fdm, comm):
181181 assert parent is not None
182182
183183 test , trial = fctx .J .arguments ()
184- fV = test .function_space ()
184+ fV = trial .function_space ()
185185 cele = self .coarsen_element (fV .ufl_element ())
186186
187187 # Have we already done this?
188188 cctx = fctx ._coarse
189189 if cctx is not None :
190- cV = cctx .J .arguments ()[0 ].function_space ()
191- if (cV .ufl_element () == cele ) and (cV .mesh () == fV .mesh ()):
190+ cV = cctx .J .arguments ()[1 ].function_space ()
191+ if (cV .ufl_element () == cele ) and (cV .mesh () == fV .mesh ()) and all ( cV_ . boundary_set == fV_ . boundary_set for cV_ , fV_ in zip ( cV , fV )) :
192192 return cV .dm
193193
194- cV = firedrake . FunctionSpace ( fV .mesh (), cele )
194+ cV = fV .reconstruct ( element = cele )
195195 cdm = cV .dm
196196
197197 fproblem = fctx ._problem
198- fu = fproblem .u
198+ fu = fproblem .u_restrict
199199 cu = firedrake .Function (cV )
200200
201201 fdeg = PMGBase .max_degree (fV .ufl_element ())
@@ -370,8 +370,8 @@ def create_transfer(self, mat_type, cctx, fctx, cbcs, fbcs):
370370 construct_mat = prolongation_matrix_aij
371371 else :
372372 raise ValueError ("Unknown matrix type" )
373- cV = cctx .J . arguments ()[ 0 ] .function_space ()
374- fV = fctx .J . arguments ()[ 0 ] .function_space ()
373+ cV = cctx ._problem . u_restrict .function_space ()
374+ fV = fctx ._problem . u_restrict .function_space ()
375375 cbcs = tuple (cctx ._problem .bcs ) if cbcs else tuple ()
376376 fbcs = tuple (fctx ._problem .bcs ) if fbcs else tuple ()
377377 return cache .setdefault (key , construct_mat (cV , fV , cbcs , fbcs ))
@@ -1179,7 +1179,7 @@ def make_permutation_code(V, vshape, pshape, t_in, t_out, array_name):
11791179
11801180def reference_value_space (V ):
11811181 element = finat .ufl .WithMapping (V .ufl_element (), mapping = "identity" )
1182- return firedrake . FunctionSpace ( V . mesh (), element )
1182+ return V . collapse (). reconstruct ( element = element )
11831183
11841184
11851185class StandaloneInterpolationMatrix (object ):
@@ -1206,13 +1206,13 @@ def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
12061206 self .Vf = reference_value_space (self .Vf )
12071207 self .uc = firedrake .Function (self .Vc , val = self .uc .dat )
12081208 self .uf = firedrake .Function (self .Vf , val = self .uf .dat )
1209- self .Vc_bcs = [bc .reconstruct (V = self .Vc ) for bc in self .Vc_bcs ]
1210- self .Vf_bcs = [bc .reconstruct (V = self .Vf ) for bc in self .Vf_bcs ]
1209+ self .Vc_bcs = [bc .reconstruct (V = self .Vc , g = 0 ) for bc in self .Vc_bcs ]
1210+ self .Vf_bcs = [bc .reconstruct (V = self .Vf , g = 0 ) for bc in self .Vf_bcs ]
12111211
12121212 def work_function (self , V ):
12131213 if isinstance (V , firedrake .Function ):
12141214 return V
1215- key = (V .ufl_element (), V .mesh ())
1215+ key = (V .ufl_element (), V .mesh (), tuple ( V . boundary_set ) )
12161216 try :
12171217 return self ._cache_work [key ]
12181218 except KeyError :
@@ -1337,17 +1337,14 @@ def make_blas_kernels(self, Vf, Vc):
13371337 restrict = ["" ]* 5
13381338 # get embedding element for Vf with identity mapping and collocated vector component DOFs
13391339 try :
1340- qelem = felem
1341- if qelem .mapping () != "identity" :
1342- qelem = qelem .reconstruct (mapping = "identity" )
1343- Qf = Vf if qelem == felem else firedrake .FunctionSpace (Vf .mesh (), qelem )
1340+ Qf = Vf if felem .mapping () == "identity" else Vf .reconstruct (mapping = "identity" )
13441341 mapping_output = make_mapping_code (Qf , cmapping , fmapping , "t0" , "t1" )
13451342 in_place_mapping = True
13461343 except Exception :
13471344 qelem = finat .ufl .FiniteElement ("DQ" , cell = felem .cell , degree = PMGBase .max_degree (felem ))
13481345 if Vf .value_shape :
13491346 qelem = finat .ufl .TensorElement (qelem , shape = Vf .value_shape , symmetry = felem .symmetry ())
1350- Qf = firedrake . FunctionSpace ( Vf .mesh (), qelem )
1347+ Qf = Vf .reconstruct ( element = qelem )
13511348 mapping_output = make_mapping_code (Qf , cmapping , fmapping , "t0" , "t1" )
13521349
13531350 qshape = (Qf .block_size , Qf .finat_element .space_dimension ())
0 commit comments