1+ from itertools import repeat
2+
13from firedrake .preconditioners .base import PCBase
24from firedrake .preconditioners .patch import bcdofs
35from firedrake .preconditioners .facet_split import get_restriction_indices
68from firedrake .ufl_expr import TestFunction , TrialFunction
79from firedrake .function import Function
810from firedrake .functionspace import FunctionSpace , VectorFunctionSpace , TensorFunctionSpace
9- from firedrake .preconditioners .fdm import tabulate_exterior_derivative
11+ from firedrake .preconditioners .fdm import broken_function , tabulate_exterior_derivative
1012from firedrake .preconditioners .hiptmair import curl_to_grad
11- from ufl import H1 , H2 , inner , dx , JacobianDeterminant
13+ from firedrake .parloops import par_loop , INC , READ
14+ from firedrake .utils import cached_property
15+ from firedrake .bcs import DirichletBC
16+ from firedrake .mesh import Submesh
17+ from ufl import Form , H1 , H2 , JacobianDeterminant , dx , inner , replace
18+ from finat .ufl import BrokenElement
19+ from pyop2 .mpi import COMM_SELF
1220from pyop2 .utils import as_tuple
13- import gem
1421import numpy
1522
1623__all__ = ("BDDCPC" ,)
@@ -23,6 +30,8 @@ class BDDCPC(PCBase):
2330
2431 Internally, this PC creates a PETSc PCBDDC object that can be controlled by
2532 the options:
33+ - ``'bddc_cellwise'`` to set up a MatIS on cellwise subdomains if P.type == python,
34+ - ``'bddc_matfree'`` to set up a matrix-free MatIS if A.type == python,
2635 - ``'bddc_pc_bddc_neumann'`` to set sub-KSPs on subdomains excluding corners,
2736 - ``'bddc_pc_bddc_dirichlet'`` to set sub-KSPs on subdomain interiors,
2837 - ``'bddc_pc_bddc_coarse'`` to set the coarse solver KSP.
@@ -34,36 +43,59 @@ class BDDCPC(PCBase):
3443 - ``'get_divergence_mat'`` for problems in H(div) (resp. 2D H(curl)), this is
3544 provide the arguments (a Mat with the assembled bilinear form testing the divergence
3645 (curl) against an L2 space) and keyword arguments supplied to ``PETSc.PC.setDivergenceMat``.
46+ - ``'primal_markers'`` a Function marking degrees of freedom of the solution space to be included in the
47+ coarse space. Any nonzero value is counted as a marked degree of freedom.
48+ If a DG(0) Function is provided, then all degrees of freedom on the cell are marked.
49+ Alternatively, ``'primal_markers'`` can be a list of the global degrees of freedom to
50+ be supplied directly to ``PETSc.PC.setBDDCPrimalVerticesIS``.
3751 """
3852
3953 _prefix = "bddc_"
4054
4155 def initialize (self , pc ):
42- # Get context from pc
43- _ , P = pc .getOperators ()
44- dm = pc .getDM ()
45- self .prefix = (pc .getOptionsPrefix () or "" ) + self ._prefix
56+ prefix = (pc .getOptionsPrefix () or "" ) + self ._prefix
4657
58+ dm = pc .getDM ()
4759 V = get_function_space (dm )
48- variant = V .ufl_element ().variant ()
49- sobolev_space = V .ufl_element ().sobolev_space
5060
5161 # Create new PC object as BDDC type
5262 bddcpc = PETSc .PC ().create (comm = pc .comm )
5363 bddcpc .incrementTabLevel (1 , parent = pc )
54- bddcpc .setOptionsPrefix (self .prefix )
55- bddcpc .setOperators (* pc .getOperators ())
64+ bddcpc .setOptionsPrefix (prefix )
5665 bddcpc .setType (PETSc .PC .Type .BDDC )
5766
5867 opts = PETSc .Options (bddcpc .getOptionsPrefix ())
68+ matfree = opts .getBool ("matfree" , False )
69+
70+ # Set operators
71+ assemblers = []
72+ A , P = pc .getOperators ()
73+ if P .type == "python" :
74+ # Reconstruct P as MatIS
75+ cellwise = opts .getBool ("cellwise" , False )
76+ P , assembleP = create_matis (P , "aij" , cellwise = cellwise )
77+ assemblers .append (assembleP )
78+
79+ if P .type != "is" :
80+ raise ValueError (f"Expecting P to be either 'matfree' or 'is', not { P .type } ." )
81+
82+ if A .type == "python" and matfree :
83+ # Reconstruct A as MatIS
84+ A , assembleA = create_matis (A , "matfree" , cellwise = P .getISAllowRepeated ())
85+ assemblers .append (assembleA )
86+ bddcpc .setOperators (A , P )
87+ self .assemblers = assemblers
88+
5989 # Do not use CSR of local matrix to define dofs connectivity unless requested
6090 # Using the CSR only makes sense for H1/H2 problems
61- is_h1h2 = sobolev_space in [ H1 , H2 ]
62- if "pc_bddc_use_local_mat_graph" not in opts and (not is_h1h2 or not is_lagrange ( V .finat_element ) ):
91+ is_h1h2 = V . ufl_element (). sobolev_space in { H1 , H2 }
92+ if "pc_bddc_use_local_mat_graph" not in opts and (not is_h1h2 or not V .finat_element . has_pointwise_dual_basis ):
6393 opts ["pc_bddc_use_local_mat_graph" ] = False
6494
65- # Handle boundary dofs
95+ # Get context from DM
6696 ctx = get_appctx (dm )
97+
98+ # Handle boundary dofs
6799 bcs = tuple (ctx ._problem .dirichlet_bcs ())
68100 mesh = V .mesh ().unique ()
69101 if mesh .extruded and not mesh .extruded_periodic :
@@ -85,16 +117,17 @@ def initialize(self, pc):
85117 bddcpc .setBDDCNeumannBoundaries (neu_bndr )
86118
87119 appctx = self .get_appctx (pc )
88- degree = max (as_tuple (V .ufl_element ().degree ()))
89120
90121 # Set coordinates only if corner selection is requested
91122 # There's no API to query from PC
92123 if "pc_bddc_corner_selection" in opts :
93- W = VectorFunctionSpace (V .mesh (), "Lagrange" , degree , variant = variant )
94- coords = Function (W ).interpolate (V .mesh ().coordinates )
124+ degree = max (as_tuple (V .ufl_element ().degree ()))
125+ variant = V .ufl_element ().variant ()
126+ W = VectorFunctionSpace (mesh , "Lagrange" , degree , variant = variant )
127+ coords = Function (W ).interpolate (mesh .coordinates )
95128 bddcpc .setCoordinates (coords .dat .data_ro .repeat (V .block_size , axis = 0 ))
96129
97- tdim = V . mesh () .topological_dimension
130+ tdim = mesh .topological_dimension
98131 if tdim >= 2 and V .finat_element .formdegree == tdim - 1 :
99132 allow_repeated = P .getISAllowRepeated ()
100133 get_divergence = appctx .get ("get_divergence_mat" , get_divergence_mat )
@@ -116,14 +149,22 @@ def initialize(self, pc):
116149 grad_kwargs = dict ()
117150 bddcpc .setBDDCDiscreteGradient (* grad_args , ** grad_kwargs )
118151
152+ # Set the user-defined primal (coarse) degrees of freedom
153+ primal_markers = appctx .get ("primal_markers" )
154+ if primal_markers is not None :
155+ primal_indices = get_primal_indices (V , primal_markers )
156+ primal_is = PETSc .IS ().createGeneral (primal_indices .astype (PETSc .IntType ), comm = pc .comm )
157+ bddcpc .setBDDCPrimalVerticesIS (primal_is )
158+
119159 bddcpc .setFromOptions ()
120160 self .pc = bddcpc
121161
122162 def view (self , pc , viewer = None ):
123163 self .pc .view (viewer = viewer )
124164
125165 def update (self , pc ):
126- pass
166+ for c in self .assemblers :
167+ c ()
127168
128169 def apply (self , pc , x , y ):
129170 self .pc .apply (x , y )
@@ -132,6 +173,104 @@ def applyTranspose(self, pc, x, y):
132173 self .pc .applyTranspose (x , y )
133174
134175
176+ class BrokenDirichletBC (DirichletBC ):
177+ def __init__ (self , bc ):
178+ self .bc = bc
179+ V = bc .function_space ().broken_space ()
180+ g = bc ._original_arg
181+ super ().__init__ (V , g , bc .sub_domain )
182+
183+ @cached_property
184+ def nodes (self ):
185+ u = Function (self .bc .function_space ())
186+ self .bc .set (u , 1 )
187+ u = broken_function (u .function_space (), val = u .dat )
188+ return numpy .flatnonzero (u .dat .data )
189+
190+
191+ def create_matis (Amat , local_mat_type , cellwise = False ):
192+ from firedrake .assemble import get_assembler
193+
194+ def local_mesh (mesh ):
195+ key = "local_submesh"
196+ cache = mesh ._shared_data_cache ["local_submesh_cache" ]
197+ try :
198+ return cache [key ]
199+ except KeyError :
200+ if mesh .comm .size > 1 :
201+ submesh = Submesh (mesh , mesh .topological_dimension , None , ignore_halo = True , reorder = False , comm = COMM_SELF )
202+ else :
203+ submesh = None
204+ return cache .setdefault (key , submesh )
205+
206+ def local_space (V , cellwise ):
207+ mesh = local_mesh (V .mesh ().unique ())
208+ element = BrokenElement (V .ufl_element ()) if cellwise else None
209+ return V .reconstruct (mesh = mesh , element = element )
210+
211+ def local_argument (arg , cellwise ):
212+ return arg .reconstruct (function_space = local_space (arg .function_space (), cellwise ))
213+
214+ def local_integral (it ):
215+ extra_domain_integral_type_map = dict (it .extra_domain_integral_type_map ())
216+ extra_domain_integral_type_map [it .ufl_domain ()] = it .integral_type ()
217+ return it .reconstruct (domain = local_mesh (it .ufl_domain ()),
218+ extra_domain_integral_type_map = extra_domain_integral_type_map )
219+
220+ def local_bc (bc , cellwise ):
221+ V = bc .function_space ()
222+ Vsub = local_space (V , False )
223+ sub_domain = list (bc .sub_domain )
224+ if "on_boundary" in sub_domain :
225+ sub_domain .remove ("on_boundary" )
226+ sub_domain .extend (V .mesh ().unique ().exterior_facets .unique_markers )
227+
228+ valid_markers = Vsub .mesh ().unique ().exterior_facets .unique_markers
229+ sub_domain = list (set (sub_domain ) & set (valid_markers ))
230+ bc = bc .reconstruct (V = Vsub , g = 0 , sub_domain = sub_domain )
231+ if cellwise :
232+ bc = BrokenDirichletBC (bc )
233+ return bc
234+
235+ def local_to_global_map (V , cellwise ):
236+ u = Function (V )
237+ u .dat .data_wo [:] = numpy .arange (* V .dof_dset .layout_vec .getOwnershipRange ())
238+
239+ Vsub = local_space (V , False )
240+ usub = Function (Vsub ).assign (u )
241+ if cellwise :
242+ usub = broken_function (usub .function_space (), val = usub .dat )
243+ indices = usub .dat .data_ro .astype (PETSc .IntType )
244+ return PETSc .LGMap ().create (indices , comm = V .comm )
245+
246+ assert Amat .type == "python"
247+ ctx = Amat .getPythonContext ()
248+ form = ctx .a
249+ bcs = ctx .bcs
250+
251+ local_form = replace (form , {arg : local_argument (arg , cellwise ) for arg in form .arguments ()})
252+ local_form = Form (list (map (local_integral , local_form .integrals ())))
253+ local_bcs = tuple (map (local_bc , bcs , repeat (cellwise )))
254+
255+ assembler = get_assembler (local_form , bcs = local_bcs , mat_type = local_mat_type )
256+ tensor = assembler .assemble ()
257+
258+ rmap = local_to_global_map (form .arguments ()[0 ].function_space (), cellwise )
259+ cmap = local_to_global_map (form .arguments ()[1 ].function_space (), cellwise )
260+
261+ Amatis = PETSc .Mat ().createIS (Amat .getSizes (), comm = Amat .getComm ())
262+ Amatis .setISAllowRepeated (cellwise )
263+ Amatis .setLGMap (rmap , cmap )
264+ Amatis .setISLocalMat (tensor .petscmat )
265+ Amatis .setUp ()
266+ Amatis .assemble ()
267+
268+ def update ():
269+ assembler .assemble (tensor = tensor )
270+ Amatis .assemble ()
271+ return Amatis , update
272+
273+
135274def get_restricted_dofs (V , domain ):
136275 W = FunctionSpace (V .mesh (), V .ufl_element ()[domain ])
137276 indices = get_restriction_indices (V , W )
@@ -166,7 +305,7 @@ def get_discrete_gradient(V):
166305 nsp = VectorSpaceBasis ([basis ])
167306 nsp .orthonormalize ()
168307 gradient .setNullSpace (nsp .nullspace ())
169- if not is_lagrange ( Q .finat_element ) :
308+ if not Q .finat_element . has_pointwise_dual_basis :
170309 vdofs = get_restricted_dofs (Q , "vertex" )
171310 gradient .compose ('_elements_corners' , vdofs )
172311
@@ -176,23 +315,25 @@ def get_discrete_gradient(V):
176315 return grad_args , grad_kwargs
177316
178317
179- def is_lagrange (finat_element ):
180- """Returns whether finat_element.dual_basis consists only of point evaluation dofs."""
181- try :
182- Q , ps = finat_element .dual_basis
183- except NotImplementedError :
184- return False
185- # Inspect the weight matrix
186- # Lagrange elements have gem.Delta as the only terminal nodes
187- children = [Q ]
188- while children :
189- nodes = []
190- for c in children :
191- if isinstance (c , gem .Delta ):
192- pass
193- elif isinstance (c , gem .gem .Terminal ):
194- return False
195- else :
196- nodes .extend (c .children )
197- children = nodes
198- return True
318+ def get_primal_indices (V , primal_markers ):
319+ if isinstance (primal_markers , Function ):
320+ marker_space = primal_markers .function_space ()
321+ if marker_space == V :
322+ markers = primal_markers
323+ elif marker_space .finat_element .space_dimension () == 1 :
324+ shapes = (V .finat_element .space_dimension (), V .block_size )
325+ domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
326+ instructions = """
327+ for i, j
328+ w[i,j] = w[i,j] + t[0]
329+ end
330+ """
331+ markers = Function (V )
332+ par_loop ((domain , instructions ), dx , {"w" : (markers , INC ), "t" : (primal_markers , READ )})
333+ else :
334+ raise ValueError (f"Expecting markers in either { V .ufl_element ()} or DG(0)." )
335+ primal_indices = numpy .flatnonzero (markers .dat .data >= 1E-12 )
336+ primal_indices += V .dof_dset .layout_vec .getOwnershipRange ()[0 ]
337+ else :
338+ primal_indices = numpy .asarray (primal_markers , dtype = PETSc .IntType )
339+ return primal_indices
0 commit comments