Skip to content
47 changes: 20 additions & 27 deletions firedrake/dmhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,43 +363,36 @@ def create_subdm(dm, fields, *args, **kwargs):
:arg fields: The fields in the new sub-DM.
"""
W = get_function_space(dm)
ctx = get_appctx(dm)
coarsen = get_ctx_coarsener(dm)
parent = get_parent(dm)
if len(fields) == 1:
# Subspace is just a single FunctionSpace.
idx, = fields
subdm = W[idx].dm
subspace = W[idx]
iset = W._ises[idx]
add_hook(parent, setup=partial(push_parent, subdm, parent), teardown=partial(pop_parent, subdm, parent),
call_setup=True)

if ctx is not None:
ctx, = ctx.split([(idx, )])
add_hook(parent, setup=partial(push_appctx, subdm, ctx), teardown=partial(pop_appctx, subdm, ctx),
call_setup=True)
add_hook(parent, setup=partial(push_ctx_coarsener, subdm, coarsen), teardown=partial(pop_ctx_coarsener, subdm, coarsen),
call_setup=True)
return iset, subdm
else:
# Need to build an MFS for the subspace
subspace = firedrake.MixedFunctionSpace([W[f] for f in fields])

add_hook(parent, setup=partial(push_parent, subspace.dm, parent), teardown=partial(pop_parent, subspace.dm, parent),
call_setup=True)
# Index set mapping from W into subspace.
iset = PETSc.IS().createGeneral(numpy.concatenate([W._ises[f].indices
for f in fields]),
iset = PETSc.IS().createGeneral(numpy.concatenate([W.dof_dset.field_ises[f].indices for f in fields]),
comm=W._comm)
if ctx is not None:
ctx, = ctx.split([fields])
add_hook(parent, setup=partial(push_appctx, subspace.dm, ctx),
teardown=partial(pop_appctx, subspace.dm, ctx),
call_setup=True)
add_hook(parent, setup=partial(push_ctx_coarsener, subspace.dm, coarsen),
teardown=partial(pop_ctx_coarsener, subspace.dm, coarsen),
call_setup=True)
return iset, subspace.dm

subdm = subspace.dm
parent = get_parent(dm)
add_hook(parent, setup=partial(push_parent, subdm, parent),
teardown=partial(pop_parent, subdm, parent),
call_setup=True)

ctx = get_appctx(dm)
coarsen = get_ctx_coarsener(dm)
if ctx is not None:
ctx, = ctx.split([fields])
add_hook(parent, setup=partial(push_appctx, subdm, ctx),
teardown=partial(pop_appctx, subdm, ctx),
call_setup=True)
add_hook(parent, setup=partial(push_ctx_coarsener, subdm, coarsen),
teardown=partial(pop_ctx_coarsener, subdm, coarsen),
call_setup=True)
return iset, subdm


@PETSc.Log.EventDecorator()
Expand Down
20 changes: 17 additions & 3 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import numpy
import collections

Expand Down Expand Up @@ -161,6 +160,7 @@ def cofunction(self, o):
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))

def matrix(self, o):
from firedrake.bcs import DirichletBC, EquationBC
ises = []
args = []
for a in o.arguments():
Expand All @@ -180,8 +180,22 @@ def matrix(self, o):
args.append(asplit)

submat = o.petscmat.createSubMatrix(*ises)
bcs = ()
return AssembledMatrix(tuple(args), bcs, submat)
bcs = []
spaces = [a.function_space() for a in o.arguments()]
for bc in o.bcs:
W = bc.function_space()
W = W.parent or W

number = spaces.index(W)
V = args[number].function_space()
field = self.blocks[number]
if isinstance(bc, DirichletBC):
sub_bc = bc.reconstruct(field=field, V=V, g=bc.function_arg)
elif isinstance(bc, EquationBC):
raise NotImplementedError("Please get in touch if you need this")
if sub_bc is not None:
bcs.append(sub_bc)
return AssembledMatrix(tuple(args), tuple(bcs), submat)

def zero_base_form(self, o):
return ZeroBaseForm(tuple(map(self, o.arguments())))
Expand Down
10 changes: 7 additions & 3 deletions firedrake/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,7 @@ def __init__(self, a, bcs, mat_type, *args, **kwargs):
self.mat_type = mat_type

def assemble(self):
raise NotImplementedError("API compatibility to apply bcs after 'assemble(a)'\
has been removed. Use 'assemble(a, bcs=bcs)', which\
now returns an assembled matrix.")
self.M.assemble()


class ImplicitMatrix(MatrixBase):
Expand Down Expand Up @@ -250,3 +248,9 @@ def __init__(self, a, bcs, petscmat, *args, **kwargs):

def mat(self):
return self.petscmat

def assemble(self):
# Bump petsc matrix state by assembling it.
# Ensures that if the matrix changed, the preconditioner is
# updated if necessary.
self.petscmat.assemble()
54 changes: 40 additions & 14 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,15 @@ def split(self, fields):
splits = []
problem = self._problem
splitter = ExtractSubBlock()

Fbig = problem.F
# Reuse the submatrices if we are splitting a MatNest
Jbig = self._jac if self._jac.petscmat.type == "nest" else problem.J
Jpbig = self._pjac if self._pjac.petscmat.type == "nest" else problem.Jp

for field in fields:
F = splitter.split(problem.F, argument_indices=(field, ))
J = splitter.split(problem.J, argument_indices=(field, field))
F = splitter.split(Fbig, argument_indices=(field, ))
J = splitter.split(Jbig, argument_indices=(field, field))
us = problem.u_restrict.subfunctions
V = F.arguments()[0].function_space()
# Exposition:
Expand Down Expand Up @@ -397,7 +403,6 @@ def split(self, fields):
# solving for, and some spaces that have just become
# coefficients in the new form.
u = as_vector(vec)
J = replace(J, {problem.u_restrict: u})
if problem.is_linear and isinstance(J, MatrixBase):
# The BC lifting term is action(MatrixBase, u).
# We cannot replace u with the split solution, as action expects a Function.
Expand All @@ -407,23 +412,35 @@ def split(self, fields):
F += problem.compute_bc_lifting(J, subu)
else:
F = replace(F, {problem.u_restrict: u})
if problem.Jp is not None:
Jp = splitter.split(problem.Jp, argument_indices=(field, field))

J = replace(J, {problem.u_restrict: u})
if Jpbig is not None:
Jp = splitter.split(Jpbig, argument_indices=(field, field))
Jp = replace(Jp, {problem.u_restrict: u})
else:
Jp = None
bcs = []
for bc in problem.bcs:
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
elif isinstance(bc, EquationBC):
bc_temp = bc.reconstruct(V, subu, u, field, problem.is_linear)
if bc_temp is not None:
bcs.append(bc_temp)

if isinstance(J, MatrixBase) and J.has_bcs:
# The BCs of the problem are already encoded in the Jacobian
bcs = None
else:
bcs = []
for bc in problem.bcs:
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg)
elif isinstance(bc, EquationBC):
bc_temp = bc.reconstruct(V, subu, u, field, problem.is_linear)
if bc_temp is not None:
bcs.append(bc_temp)

new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
new_problem._constant_jacobian = problem._constant_jacobian
splits.append(type(self)(new_problem, mat_type=self.mat_type, pmat_type=self.pmat_type,
splits.append(type(self)(new_problem,
mat_type=self.mat_type,
pmat_type=self.pmat_type,
sub_mat_type=self.sub_mat_type,
sub_pmat_type=self.sub_pmat_type,
appctx=self.appctx,
transfer_manager=self.transfer_manager,
pre_apply_bcs=self.pre_apply_bcs))
Expand Down Expand Up @@ -504,6 +521,15 @@ def form_jacobian(snes, X, J, P):
ctx.set_nullspace(ctx._nullspace_T, ises, transpose=True, near=False)
ctx.set_nullspace(ctx._near_nullspace, ises, transpose=False, near=True)

# Bump petsc matrix state of each split by assembling them.
# Ensures that if the matrix changed, the preconditioner is
# updated if necessary.
for fields, splits in ctx._splits.items():
for subctx in splits:
subctx._jac.assemble()
if subctx.Jp is not None:
subctx._pjac.assemble()

@staticmethod
def compute_operators(ksp, J, P):
r"""Form the Jacobian for this problem
Expand Down
43 changes: 43 additions & 0 deletions tests/firedrake/regression/test_nested_fieldsplit_solves.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,49 @@ def test_nested_fieldsplit_solve_parallel(W, A, b, expect):
assert norm(f) < 1e-11


@pytest.mark.parametrize("mat_type,pmat_type", [("nest", "nest"), ("matfree", "nest")])
def test_nonlinear_fieldsplit(mat_type, pmat_type):
mesh = UnitIntervalMesh(1)
V = FunctionSpace(mesh, "DG", 0)
Z = V * V * V

u = Function(Z)
u0, u1, u2 = split(u)
v0, v1, v2 = TestFunctions(Z)

F = inner(u0, v0) * dx
F += inner(0.5*u1**2 + u1, v1) * dx
F += inner(u2, v2) * dx
u.subfunctions[1].assign(Constant(1))

sp = {
"mat_type": mat_type,
"pmat_type": pmat_type,
"snes_max_it": 10,
"ksp_type": "fgmres",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "additive",
"pc_fieldsplit_0_fields": "0",
"pc_fieldsplit_1_fields": "1,2",
"fieldsplit_1_ksp_view_eigenvalues": None,
"fieldsplit": {
"ksp_type": "gmres",
"pc_type": "jacobi",
},
}
problem = NonlinearVariationalProblem(F, u)
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)

def mymonitor(snes, it, fnorm):
if it == 0:
# This call happens before the first linear solve
return
assert np.allclose(snes.ksp.pc.getFieldSplitSubKSP()[1].computeEigenvalues(), 1)

solver.snes.setMonitor(mymonitor)
solver.solve()


def test_matrix_types(W):
a = inner(TrialFunction(W), TestFunction(W))*dx

Expand Down
25 changes: 25 additions & 0 deletions tests/firedrake/regression/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ def test_assemble_split_mixed_derivative():
assert np.allclose(actual.M.values, expect.M.values)


@pytest.mark.parametrize("mat_type", ("aij", "nest"))
@pytest.mark.parametrize("bcs", (True, False))
def test_split_assembled_matrix(mat_type, bcs):
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V * Q
bcs = [DirichletBC(Z.sub(0), 0, "on_boundary")] if bcs else []

test = TestFunction(Z)
trial = TrialFunction(Z)

a = inner(test, trial)*dx
A = assemble(a, bcs=bcs, mat_type=mat_type)

splitter = ExtractSubBlock()
actual = splitter.split(A, (0, 0))

bcs = [bc.reconstruct(V=V) for bc in bcs]
expect = assemble(splitter.split(a, (0, 0)), bcs=bcs)

expect.petscmat.axpy(-1, actual.petscmat)
assert np.allclose(expect.petscmat[:, :], 0)


def test_split_coordinate_derivative():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "P", 1)
Expand Down
Loading