Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c55b648
dmhooks: support RestrictedFunctionSpace
pbrubeck Mar 28, 2025
1b20212
add tests
pbrubeck Mar 28, 2025
475fe05
Merge branch 'master' into pbrubeck/fix/restricted-dmhooks
pbrubeck Mar 31, 2025
fd42aa1
RestrictedFunctionSpace: support geometric multigrid
pbrubeck Mar 31, 2025
3c63f6c
RestrictedFunctionSpace: support p-multigrid
pbrubeck Mar 31, 2025
7a07dff
cleanup
pbrubeck Mar 31, 2025
f165b4b
PC: FunctionSpace() -> V.reconstruct()
pbrubeck Mar 31, 2025
6512d47
Merge branch 'master' into pbrubeck/fix/restricted-dmhooks
pbrubeck Apr 10, 2025
c0287de
Merge branch 'pbrubeck/fix/restricted-dmhooks' of github.com:firedrak…
pbrubeck Apr 10, 2025
82b7b74
Update firedrake/preconditioners/patch.py
pbrubeck Apr 10, 2025
511c44b
merge conflict
pbrubeck Jul 31, 2025
8a811fd
Remove RestrictedFunctionSpace.__new__
pbrubeck Jul 31, 2025
0a4ff80
Merge branch 'main' into pbrubeck/fix/restricted-dmhooks
pbrubeck Aug 1, 2025
172ef70
mesh cache key
pbrubeck Aug 8, 2025
602c5a6
small change
pbrubeck Aug 8, 2025
c6f9e20
Update firedrake/mg/utils.py
pbrubeck Aug 19, 2025
6065958
Merge branch 'main' into pbrubeck/fix/restricted-dmhooks
pbrubeck Aug 20, 2025
e74b450
Fixes for RieszMap
pbrubeck Aug 20, 2025
e4e6817
Cache key with frozenset
pbrubeck Aug 20, 2025
4c7b68f
Merge branch 'main' into pbrubeck/fix/restricted-dmhooks
pbrubeck Aug 21, 2025
cfc896e
Cache on both entity_dof_keys
pbrubeck Aug 21, 2025
8f53ed6
Apply suggestions from code review
pbrubeck Aug 29, 2025
92eb89e
Merge branch 'main' into pbrubeck/fix/restricted-dmhooks
pbrubeck Aug 29, 2025
d43e944
Fix cache key
pbrubeck Aug 29, 2025
4457985
Fix reconstruct
pbrubeck Aug 29, 2025
d29f210
cleanup
pbrubeck Aug 29, 2025
ddc189a
Fix cache key
pbrubeck Aug 29, 2025
3e80f46
Fixup
pbrubeck Aug 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal
V = V.sub(index)
if g is None:
g = self._original_arg
if isinstance(g, firedrake.Function) and g.function_space() != V:
g = firedrake.Function(V).interpolate(g)
if sub_domain is None:
sub_domain = self.sub_domain
if field is not None:
Expand Down Expand Up @@ -739,11 +741,11 @@ def restricted_function_space(V, ids):
return V

assert len(ids) == len(V)
spaces = [Vsub if len(boundary_set) == 0 else
firedrake.RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
for Vsub, boundary_set in zip(V, ids)]
spaces = [V_ if len(boundary_set) == 0 else
firedrake.RestrictedFunctionSpace(V_, boundary_set=boundary_set, name=V_.name)
for V_, boundary_set in zip(V, ids)]

if len(spaces) == 1:
return spaces[0]
else:
return firedrake.MixedFunctionSpace(spaces)
return firedrake.MixedFunctionSpace(spaces, name=V.name)
10 changes: 6 additions & 4 deletions firedrake/dmhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ def get_function_space(dm):
:raises RuntimeError: if no function space was found.
"""
info = dm.getAttr("__fs_info__")
meshref, element, indices, (name, names) = info
meshref, element, indices, (name, names), boundary_sets = info
mesh = meshref()
if mesh is None:
raise RuntimeError("Somehow your mesh was collected, this should never happen")
V = firedrake.FunctionSpace(mesh, element, name=name)
if any(boundary_sets):
V = firedrake.bcs.restricted_function_space(V, boundary_sets)
if len(V) > 1:
for V_, name in zip(V, names):
V_.topological.name = name
Expand Down Expand Up @@ -93,8 +95,8 @@ def set_function_space(dm, V):
if len(V) > 1:
names = tuple(V_.name for V_ in V)
element = V.ufl_element()

info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
boundary_sets = tuple(V_.boundary_set for V_ in V)
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets)
dm.setAttr("__fs_info__", info)


Expand Down Expand Up @@ -457,7 +459,7 @@ def refine(dm, comm):
if hasattr(V, "_fine"):
fdm = V._fine.dm
else:
V._fine = firedrake.FunctionSpace(hierarchy[level + 1], V.ufl_element())
V._fine = V.reconstruct(mesh=hierarchy[level + 1])
fdm = V._fine.dm
V._fine._coarse = V
return fdm
Expand Down
19 changes: 13 additions & 6 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def collapse(self):
return type(self).create(self.topological.collapse(), self.mesh())

@classmethod
def make_function_space(cls, mesh, element, name=None):
def make_function_space(cls, mesh, element, name=None, boundary_set=None):
r"""Factory method for :class:`.WithGeometryBase`."""
mesh.init()
topology = mesh.topology
Expand All @@ -376,12 +376,18 @@ def make_function_space(cls, mesh, element, name=None):
if mesh is not topology:
# Create a concrete WithGeometry or FiredrakeDualSpace on this mesh
new = cls.create(new, mesh)

if boundary_set:
new = RestrictedFunctionSpace(new, boundary_set=boundary_set)
if mesh is not topology:
new = cls.create(new, mesh)
return new

def reconstruct(self, mesh=None, name=None, **kwargs):
def reconstruct(self, mesh=None, element=None, name=None, **kwargs):
r"""Reconstruct this :class:`.WithGeometryBase` .

:kwarg mesh: the new :func:`~.Mesh` (defaults to same mesh)
:kwarg element: the new :class:`finat.ufl.FiniteElement` (defaults to same element)
:kwarg name: the new name (defaults to None)
:returns: the new function space of the same class as ``self``.

Expand All @@ -404,12 +410,14 @@ def reconstruct(self, mesh=None, name=None, **kwargs):
if mesh is None:
mesh = V_parent.mesh()

element = V_parent.ufl_element()
if element is None:
element = V_parent.ufl_element()
cell = mesh.topology.ufl_cell()
if len(kwargs) > 0 or element.cell != cell:
element = element.reconstruct(cell=cell, **kwargs)

V = type(self).make_function_space(mesh, element, name=name)
V = type(self).make_function_space(mesh, element, name=name,
boundary_set=V_parent.boundary_set)
for i in reversed(indices):
V = V.sub(i)
return V
Expand Down Expand Up @@ -901,8 +909,7 @@ def __init__(self, function_space, boundary_set=frozenset(), name=None):
function_space.ufl_element(),
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(map(str, self.boundary_set))))
self.name = name or function_space.name

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
Expand Down
25 changes: 15 additions & 10 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,17 @@ def prolong(coarse, fine):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = coarse_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes
for j in range(repeat):
next_level += 1
if j == repeat - 1:
next = fine
Vf = fine.function_space()
else:
Vf = firedrake.FunctionSpace(meshes[next_level], element)
Vf = Vc.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vf)

coarse_coords = Vc.mesh().coordinates
coarse_coords = get_coordinates(Vc)
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
kernel = kernels.prolong_kernel(coarse)
Expand Down Expand Up @@ -119,7 +118,6 @@ def restrict(fine_dual, coarse_dual):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes

for j in range(repeat):
Expand All @@ -128,15 +126,15 @@ def restrict(fine_dual, coarse_dual):
coarse_dual.dat.zero()
next = coarse_dual
else:
Vc = firedrake.FunctionSpace(meshes[next_level], element)
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Cofunction(Vc.dual())
Vc = next.function_space()
# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
node_locations = utils.physical_node_locations(Vf)
node_locations = utils.physical_node_locations(Vf.dual())

coarse_coords = Vc.mesh().coordinates
coarse_coords = get_coordinates(Vc.dual())
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
# Have to do this, because the node set core size is not right for
Expand Down Expand Up @@ -195,7 +193,6 @@ def inject(fine, coarse):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes

for j in range(repeat):
Expand All @@ -205,12 +202,12 @@ def inject(fine, coarse):
next = coarse
Vc = next.function_space()
else:
Vc = firedrake.FunctionSpace(meshes[next_level], element)
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vc)
if not dg:
node_locations = utils.physical_node_locations(Vc)

fine_coords = Vf.mesh().coordinates
fine_coords = get_coordinates(Vf)
coarse_node_to_fine_nodes = utils.coarse_node_to_fine_node_map(Vc, Vf)
coarse_node_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())

Expand Down Expand Up @@ -242,3 +239,11 @@ def inject(fine, coarse):
fine = next
Vf = Vc
return coarse


def get_coordinates(V):
coords = V.mesh().coordinates
if V.boundary_set:
W = V.reconstruct(element=coords.function_space().ufl_element())
coords = firedrake.Function(W).interpolate(coords)
return coords
14 changes: 7 additions & 7 deletions firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
if isinstance(g, firedrake.Function) and hasattr(g, "_child"):
manager.inject(g, g._child)

V = problem.u.function_space()
V = problem.u_restrict.function_space()
if not hasattr(V, "_coarse"):
# The hook is persistent and cumulative, but also problem-independent.
# Therefore, we are only adding it once.
Expand All @@ -201,7 +201,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
for c in coefficients:
coefficient_mapping[c] = self(c, self, coefficient_mapping=coefficient_mapping)

u = coefficient_mapping[problem.u]
u = coefficient_mapping[problem.u_restrict]

bcs = [self(bc, self) for bc in problem.bcs]
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
Expand Down Expand Up @@ -277,7 +277,7 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
if isinstance(val, (firedrake.Function, firedrake.Cofunction)):
V = val.function_space()
coarseneddm = V.dm
parentdm = get_parent(context._problem.u.function_space().dm)
parentdm = get_parent(context._problem.u_restrict.function_space().dm)

# Now attach the hook to the parent DM
if get_appctx(coarseneddm) is None:
Expand Down Expand Up @@ -369,8 +369,8 @@ def create_interpolation(dmc, dmf):

manager = get_transfer_manager(dmf)

V_c = cctx._problem.u.function_space()
V_f = fctx._problem.u.function_space()
V_c = cctx._problem.u_restrict.function_space()
V_f = fctx._problem.u_restrict.function_space()

row_size = V_f.dof_dset.layout_vec.getSizes()
col_size = V_c.dof_dset.layout_vec.getSizes()
Expand All @@ -395,8 +395,8 @@ def create_injection(dmc, dmf):

manager = get_transfer_manager(dmf)

V_c = cctx._problem.u.function_space()
V_f = fctx._problem.u.function_space()
V_c = cctx._problem.u_restrict.function_space()
V_f = fctx._problem.u_restrict.function_space()

row_size = V_f.dof_dset.layout_vec.getSizes()
col_size = V_c.dof_dset.layout_vec.getSizes()
Expand Down
7 changes: 4 additions & 3 deletions firedrake/mg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def fine_node_to_coarse_node_map(Vf, Vc):
if len(Vf) > 1:
assert len(Vf) == len(Vc)
return op2.MixedMap(fine_node_to_coarse_node_map(f, c) for f, c in zip(Vf, Vc))
return op2.MixedMap(map(fine_node_to_coarse_node_map, Vf, Vc))
mesh = Vf.mesh()
assert hasattr(mesh, "_shared_data_cache")
hierarchyf, levelf = get_level(Vf.mesh())
Expand Down Expand Up @@ -49,7 +49,7 @@ def fine_node_to_coarse_node_map(Vf, Vc):
def coarse_node_to_fine_node_map(Vc, Vf):
if len(Vf) > 1:
assert len(Vf) == len(Vc)
return op2.MixedMap(coarse_node_to_fine_node_map(f, c) for f, c in zip(Vf, Vc))
return op2.MixedMap(map(coarse_node_to_fine_node_map, Vf, Vc))
mesh = Vc.mesh()
assert hasattr(mesh, "_shared_data_cache")
hierarchyf, levelf = get_level(Vf.mesh())
Expand Down Expand Up @@ -146,7 +146,8 @@ def physical_node_locations(V):
try:
return cache[key]
except KeyError:
Vc = firedrake.VectorFunctionSpace(mesh, element)
Vc = V.reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension()))

# FIXME: This is unsafe for DG coordinates and CG target spaces.
locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc))
return cache.setdefault(key, locations)
Expand Down
1 change: 0 additions & 1 deletion pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def _wrapper_cache_key_(self):
@utils.validate_in(('access', _modes, ex.ModeValueError))
def __call__(self, access, path=None):
from pyop2.parloop import DatLegacyArg

if conf.configuration["type_check"] and path and path.toset != self.dataset.set:
raise ex.MapValueError("To Set of Map does not match Set of Dat.")
return DatLegacyArg(self, path, access)
Expand Down
94 changes: 94 additions & 0 deletions tests/firedrake/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,97 @@ def test_restricted_function_space_extrusion_stokes(ncells):
# -- Actually, the ordering is the same.
assert np.allclose(sol_res.subfunctions[0].dat.data_ro_with_halos, sol.subfunctions[0].dat.data_ro_with_halos)
assert np.allclose(sol_res.subfunctions[1].dat.data_ro_with_halos, sol.subfunctions[1].dat.data_ro_with_halos)


@pytest.mark.parametrize("names", [(None, None), (None, "name1"), ("name0", "name1")])
def test_restrict_fieldsplit(names):
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1, name=names[0])
Q = FunctionSpace(mesh, "CG", 2, name=names[1])
Z = V * Q

z = Function(Z)
test = TestFunction(Z)
z_exact = Constant([1, -1])

F = inner(z - z_exact, test) * dx
bcs = [DirichletBC(Z.sub(i), z_exact[i], (i+1, i+3)) for i in range(len(Z))]

problem = NonlinearVariationalProblem(F, z, bcs=bcs, restrict=True)
solver = NonlinearVariationalSolver(problem, solver_parameters={
"snes_type": "ksponly",
"ksp_type": "preonly",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "additive",
f"fieldsplit_{names[0] or 0}_pc_type": "lu",
f"fieldsplit_{names[1] or 1}_pc_type": "lu"},
options_prefix="")
solver.solve()

# Test prefixes for the restricted spaces
pc = solver.snes.ksp.pc
for field, ksp in enumerate(pc.getFieldSplitSubKSP()):
name = Z[field].name or field
assert ksp.getOptionsPrefix() == f"fieldsplit_{name}_"

assert errornorm(z_exact[0], z.subfunctions[0]) < 1E-10
assert errornorm(z_exact[1], z.subfunctions[1]) < 1E-10


def test_restrict_python_pc():
mesh = UnitSquareMesh(2, 2)
V = FunctionSpace(mesh, "CG", 1)
u = Function(V)
test = TestFunction(V)

x, y = SpatialCoordinate(mesh)
u_exact = x + y
g = Function(V).interpolate(u_exact)

F = inner(u - u_exact, test) * dx
bcs = [DirichletBC(V, g, 1), DirichletBC(V, u_exact, 2)]

problem = NonlinearVariationalProblem(F, u, bcs=bcs, restrict=True)
solver = NonlinearVariationalSolver(problem, solver_parameters={
"snes_type": "ksponly",
"mat_type": "matfree",
"ksp_type": "preonly",
"pc_type": "python",
"pc_python_type": "firedrake.AssembledPC",
"assembled_pc_type": "lu"})
solver.solve()

assert errornorm(u_exact, u) < 1E-10


def test_restrict_multigrid():
base = UnitSquareMesh(2, 2)
refine = 2
mh = MeshHierarchy(base, refine)
mesh = mh[-1]

V = FunctionSpace(mesh, "CG", 1)
u = Function(V)
test = TestFunction(V)

x, y = SpatialCoordinate(mesh)
u_exact = x + y
g = Function(V).interpolate(u_exact)

F = inner(grad(u - u_exact), grad(test)) * dx
bcs = [DirichletBC(V, g, 1), DirichletBC(V, u_exact, 2)]

problem = NonlinearVariationalProblem(F, u, bcs=bcs, restrict=True)
solver = NonlinearVariationalSolver(problem, solver_parameters={
"snes_type": "ksponly",
"ksp_type": "cg",
"ksp_rtol": 1E-10,
"ksp_max_it": 10,
"ksp_monitor": None,
"pc_type": "mg",
"mg_levels_ksp_type": "chebyshev",
"mg_levels_pc_type": "jacobi",
"mg_coarse_pc_type": "lu"})
solver.solve()

assert errornorm(u_exact, u) < 1E-10
Loading