Skip to content

Commit fb3078f

Browse files
authored
BDDC: cellwise subdomains, matfree MatIS, and user-defined primal vertices (#4757)
1 parent 0f6aa11 commit fb3078f

File tree

2 files changed

+273
-51
lines changed

2 files changed

+273
-51
lines changed

firedrake/preconditioners/bddc.py

Lines changed: 181 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from itertools import repeat
2+
13
from firedrake.preconditioners.base import PCBase
24
from firedrake.preconditioners.patch import bcdofs
35
from firedrake.preconditioners.facet_split import get_restriction_indices
@@ -6,11 +8,16 @@
68
from firedrake.ufl_expr import TestFunction, TrialFunction
79
from firedrake.function import Function
810
from firedrake.functionspace import FunctionSpace, VectorFunctionSpace, TensorFunctionSpace
9-
from firedrake.preconditioners.fdm import tabulate_exterior_derivative
11+
from firedrake.preconditioners.fdm import broken_function, tabulate_exterior_derivative
1012
from firedrake.preconditioners.hiptmair import curl_to_grad
11-
from ufl import H1, H2, inner, dx, JacobianDeterminant
13+
from firedrake.parloops import par_loop, INC, READ
14+
from firedrake.utils import cached_property
15+
from firedrake.bcs import DirichletBC
16+
from firedrake.mesh import Submesh
17+
from ufl import Form, H1, H2, JacobianDeterminant, dx, inner, replace
18+
from finat.ufl import BrokenElement
19+
from pyop2.mpi import COMM_SELF
1220
from pyop2.utils import as_tuple
13-
import gem
1421
import numpy
1522

1623
__all__ = ("BDDCPC",)
@@ -23,6 +30,8 @@ class BDDCPC(PCBase):
2330
2431
Internally, this PC creates a PETSc PCBDDC object that can be controlled by
2532
the options:
33+
- ``'bddc_cellwise'`` to set up a MatIS on cellwise subdomains if P.type == python,
34+
- ``'bddc_matfree'`` to set up a matrix-free MatIS if A.type == python,
2635
- ``'bddc_pc_bddc_neumann'`` to set sub-KSPs on subdomains excluding corners,
2736
- ``'bddc_pc_bddc_dirichlet'`` to set sub-KSPs on subdomain interiors,
2837
- ``'bddc_pc_bddc_coarse'`` to set the coarse solver KSP.
@@ -34,36 +43,59 @@ class BDDCPC(PCBase):
3443
- ``'get_divergence_mat'`` for problems in H(div) (resp. 2D H(curl)), this is
3544
provide the arguments (a Mat with the assembled bilinear form testing the divergence
3645
(curl) against an L2 space) and keyword arguments supplied to ``PETSc.PC.setDivergenceMat``.
46+
- ``'primal_markers'`` a Function marking degrees of freedom of the solution space to be included in the
47+
coarse space. Any nonzero value is counted as a marked degree of freedom.
48+
If a DG(0) Function is provided, then all degrees of freedom on the cell are marked.
49+
Alternatively, ``'primal_markers'`` can be a list of the global degrees of freedom to
50+
be supplied directly to ``PETSc.PC.setBDDCPrimalVerticesIS``.
3751
"""
3852

3953
_prefix = "bddc_"
4054

4155
def initialize(self, pc):
42-
# Get context from pc
43-
_, P = pc.getOperators()
44-
dm = pc.getDM()
45-
self.prefix = (pc.getOptionsPrefix() or "") + self._prefix
56+
prefix = (pc.getOptionsPrefix() or "") + self._prefix
4657

58+
dm = pc.getDM()
4759
V = get_function_space(dm)
48-
variant = V.ufl_element().variant()
49-
sobolev_space = V.ufl_element().sobolev_space
5060

5161
# Create new PC object as BDDC type
5262
bddcpc = PETSc.PC().create(comm=pc.comm)
5363
bddcpc.incrementTabLevel(1, parent=pc)
54-
bddcpc.setOptionsPrefix(self.prefix)
55-
bddcpc.setOperators(*pc.getOperators())
64+
bddcpc.setOptionsPrefix(prefix)
5665
bddcpc.setType(PETSc.PC.Type.BDDC)
5766

5867
opts = PETSc.Options(bddcpc.getOptionsPrefix())
68+
matfree = opts.getBool("matfree", False)
69+
70+
# Set operators
71+
assemblers = []
72+
A, P = pc.getOperators()
73+
if P.type == "python":
74+
# Reconstruct P as MatIS
75+
cellwise = opts.getBool("cellwise", False)
76+
P, assembleP = create_matis(P, "aij", cellwise=cellwise)
77+
assemblers.append(assembleP)
78+
79+
if P.type != "is":
80+
raise ValueError(f"Expecting P to be either 'matfree' or 'is', not {P.type}.")
81+
82+
if A.type == "python" and matfree:
83+
# Reconstruct A as MatIS
84+
A, assembleA = create_matis(A, "matfree", cellwise=P.getISAllowRepeated())
85+
assemblers.append(assembleA)
86+
bddcpc.setOperators(A, P)
87+
self.assemblers = assemblers
88+
5989
# Do not use CSR of local matrix to define dofs connectivity unless requested
6090
# Using the CSR only makes sense for H1/H2 problems
61-
is_h1h2 = sobolev_space in [H1, H2]
62-
if "pc_bddc_use_local_mat_graph" not in opts and (not is_h1h2 or not is_lagrange(V.finat_element)):
91+
is_h1h2 = V.ufl_element().sobolev_space in {H1, H2}
92+
if "pc_bddc_use_local_mat_graph" not in opts and (not is_h1h2 or not V.finat_element.has_pointwise_dual_basis):
6393
opts["pc_bddc_use_local_mat_graph"] = False
6494

65-
# Handle boundary dofs
95+
# Get context from DM
6696
ctx = get_appctx(dm)
97+
98+
# Handle boundary dofs
6799
bcs = tuple(ctx._problem.dirichlet_bcs())
68100
mesh = V.mesh().unique()
69101
if mesh.extruded and not mesh.extruded_periodic:
@@ -85,16 +117,17 @@ def initialize(self, pc):
85117
bddcpc.setBDDCNeumannBoundaries(neu_bndr)
86118

87119
appctx = self.get_appctx(pc)
88-
degree = max(as_tuple(V.ufl_element().degree()))
89120

90121
# Set coordinates only if corner selection is requested
91122
# There's no API to query from PC
92123
if "pc_bddc_corner_selection" in opts:
93-
W = VectorFunctionSpace(V.mesh(), "Lagrange", degree, variant=variant)
94-
coords = Function(W).interpolate(V.mesh().coordinates)
124+
degree = max(as_tuple(V.ufl_element().degree()))
125+
variant = V.ufl_element().variant()
126+
W = VectorFunctionSpace(mesh, "Lagrange", degree, variant=variant)
127+
coords = Function(W).interpolate(mesh.coordinates)
95128
bddcpc.setCoordinates(coords.dat.data_ro.repeat(V.block_size, axis=0))
96129

97-
tdim = V.mesh().topological_dimension
130+
tdim = mesh.topological_dimension
98131
if tdim >= 2 and V.finat_element.formdegree == tdim-1:
99132
allow_repeated = P.getISAllowRepeated()
100133
get_divergence = appctx.get("get_divergence_mat", get_divergence_mat)
@@ -116,14 +149,22 @@ def initialize(self, pc):
116149
grad_kwargs = dict()
117150
bddcpc.setBDDCDiscreteGradient(*grad_args, **grad_kwargs)
118151

152+
# Set the user-defined primal (coarse) degrees of freedom
153+
primal_markers = appctx.get("primal_markers")
154+
if primal_markers is not None:
155+
primal_indices = get_primal_indices(V, primal_markers)
156+
primal_is = PETSc.IS().createGeneral(primal_indices.astype(PETSc.IntType), comm=pc.comm)
157+
bddcpc.setBDDCPrimalVerticesIS(primal_is)
158+
119159
bddcpc.setFromOptions()
120160
self.pc = bddcpc
121161

122162
def view(self, pc, viewer=None):
123163
self.pc.view(viewer=viewer)
124164

125165
def update(self, pc):
126-
pass
166+
for c in self.assemblers:
167+
c()
127168

128169
def apply(self, pc, x, y):
129170
self.pc.apply(x, y)
@@ -132,6 +173,104 @@ def applyTranspose(self, pc, x, y):
132173
self.pc.applyTranspose(x, y)
133174

134175

176+
class BrokenDirichletBC(DirichletBC):
177+
def __init__(self, bc):
178+
self.bc = bc
179+
V = bc.function_space().broken_space()
180+
g = bc._original_arg
181+
super().__init__(V, g, bc.sub_domain)
182+
183+
@cached_property
184+
def nodes(self):
185+
u = Function(self.bc.function_space())
186+
self.bc.set(u, 1)
187+
u = broken_function(u.function_space(), val=u.dat)
188+
return numpy.flatnonzero(u.dat.data)
189+
190+
191+
def create_matis(Amat, local_mat_type, cellwise=False):
192+
from firedrake.assemble import get_assembler
193+
194+
def local_mesh(mesh):
195+
key = "local_submesh"
196+
cache = mesh._shared_data_cache["local_submesh_cache"]
197+
try:
198+
return cache[key]
199+
except KeyError:
200+
if mesh.comm.size > 1:
201+
submesh = Submesh(mesh, mesh.topological_dimension, None, ignore_halo=True, reorder=False, comm=COMM_SELF)
202+
else:
203+
submesh = None
204+
return cache.setdefault(key, submesh)
205+
206+
def local_space(V, cellwise):
207+
mesh = local_mesh(V.mesh().unique())
208+
element = BrokenElement(V.ufl_element()) if cellwise else None
209+
return V.reconstruct(mesh=mesh, element=element)
210+
211+
def local_argument(arg, cellwise):
212+
return arg.reconstruct(function_space=local_space(arg.function_space(), cellwise))
213+
214+
def local_integral(it):
215+
extra_domain_integral_type_map = dict(it.extra_domain_integral_type_map())
216+
extra_domain_integral_type_map[it.ufl_domain()] = it.integral_type()
217+
return it.reconstruct(domain=local_mesh(it.ufl_domain()),
218+
extra_domain_integral_type_map=extra_domain_integral_type_map)
219+
220+
def local_bc(bc, cellwise):
221+
V = bc.function_space()
222+
Vsub = local_space(V, False)
223+
sub_domain = list(bc.sub_domain)
224+
if "on_boundary" in sub_domain:
225+
sub_domain.remove("on_boundary")
226+
sub_domain.extend(V.mesh().unique().exterior_facets.unique_markers)
227+
228+
valid_markers = Vsub.mesh().unique().exterior_facets.unique_markers
229+
sub_domain = list(set(sub_domain) & set(valid_markers))
230+
bc = bc.reconstruct(V=Vsub, g=0, sub_domain=sub_domain)
231+
if cellwise:
232+
bc = BrokenDirichletBC(bc)
233+
return bc
234+
235+
def local_to_global_map(V, cellwise):
236+
u = Function(V)
237+
u.dat.data_wo[:] = numpy.arange(*V.dof_dset.layout_vec.getOwnershipRange())
238+
239+
Vsub = local_space(V, False)
240+
usub = Function(Vsub).assign(u)
241+
if cellwise:
242+
usub = broken_function(usub.function_space(), val=usub.dat)
243+
indices = usub.dat.data_ro.astype(PETSc.IntType)
244+
return PETSc.LGMap().create(indices, comm=V.comm)
245+
246+
assert Amat.type == "python"
247+
ctx = Amat.getPythonContext()
248+
form = ctx.a
249+
bcs = ctx.bcs
250+
251+
local_form = replace(form, {arg: local_argument(arg, cellwise) for arg in form.arguments()})
252+
local_form = Form(list(map(local_integral, local_form.integrals())))
253+
local_bcs = tuple(map(local_bc, bcs, repeat(cellwise)))
254+
255+
assembler = get_assembler(local_form, bcs=local_bcs, mat_type=local_mat_type)
256+
tensor = assembler.assemble()
257+
258+
rmap = local_to_global_map(form.arguments()[0].function_space(), cellwise)
259+
cmap = local_to_global_map(form.arguments()[1].function_space(), cellwise)
260+
261+
Amatis = PETSc.Mat().createIS(Amat.getSizes(), comm=Amat.getComm())
262+
Amatis.setISAllowRepeated(cellwise)
263+
Amatis.setLGMap(rmap, cmap)
264+
Amatis.setISLocalMat(tensor.petscmat)
265+
Amatis.setUp()
266+
Amatis.assemble()
267+
268+
def update():
269+
assembler.assemble(tensor=tensor)
270+
Amatis.assemble()
271+
return Amatis, update
272+
273+
135274
def get_restricted_dofs(V, domain):
136275
W = FunctionSpace(V.mesh(), V.ufl_element()[domain])
137276
indices = get_restriction_indices(V, W)
@@ -166,7 +305,7 @@ def get_discrete_gradient(V):
166305
nsp = VectorSpaceBasis([basis])
167306
nsp.orthonormalize()
168307
gradient.setNullSpace(nsp.nullspace())
169-
if not is_lagrange(Q.finat_element):
308+
if not Q.finat_element.has_pointwise_dual_basis:
170309
vdofs = get_restricted_dofs(Q, "vertex")
171310
gradient.compose('_elements_corners', vdofs)
172311

@@ -176,23 +315,25 @@ def get_discrete_gradient(V):
176315
return grad_args, grad_kwargs
177316

178317

179-
def is_lagrange(finat_element):
180-
"""Returns whether finat_element.dual_basis consists only of point evaluation dofs."""
181-
try:
182-
Q, ps = finat_element.dual_basis
183-
except NotImplementedError:
184-
return False
185-
# Inspect the weight matrix
186-
# Lagrange elements have gem.Delta as the only terminal nodes
187-
children = [Q]
188-
while children:
189-
nodes = []
190-
for c in children:
191-
if isinstance(c, gem.Delta):
192-
pass
193-
elif isinstance(c, gem.gem.Terminal):
194-
return False
195-
else:
196-
nodes.extend(c.children)
197-
children = nodes
198-
return True
318+
def get_primal_indices(V, primal_markers):
319+
if isinstance(primal_markers, Function):
320+
marker_space = primal_markers.function_space()
321+
if marker_space == V:
322+
markers = primal_markers
323+
elif marker_space.finat_element.space_dimension() == 1:
324+
shapes = (V.finat_element.space_dimension(), V.block_size)
325+
domain = "{[i,j]: 0 <= i < %d and 0 <= j < %d}" % shapes
326+
instructions = """
327+
for i, j
328+
w[i,j] = w[i,j] + t[0]
329+
end
330+
"""
331+
markers = Function(V)
332+
par_loop((domain, instructions), dx, {"w": (markers, INC), "t": (primal_markers, READ)})
333+
else:
334+
raise ValueError(f"Expecting markers in either {V.ufl_element()} or DG(0).")
335+
primal_indices = numpy.flatnonzero(markers.dat.data >= 1E-12)
336+
primal_indices += V.dof_dset.layout_vec.getOwnershipRange()[0]
337+
else:
338+
primal_indices = numpy.asarray(primal_markers, dtype=PETSc.IntType)
339+
return primal_indices

0 commit comments

Comments
 (0)