Skip to content

Commit 50a90e9

Browse files
Submesh on subcommunicator (#4761)
--------- Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
1 parent 51bbd96 commit 50a90e9

File tree

6 files changed

+116
-12
lines changed

6 files changed

+116
-12
lines changed

firedrake/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# the specific version, here we are more permissive. This is to catch the
44
# case where users don't update their PETSc for a really long time or
55
# accidentally install a too-new release that isn't yet supported.
6+
# TODO RELEASE set to ">=3.25"
67
PETSC_SUPPORTED_VERSIONS = ">=3.23.0"
78

89

firedrake/cython/dmcommon.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3819,7 +3819,8 @@ def submesh_create(PETSc.DM dm,
38193819
PetscInt subdim,
38203820
label_name,
38213821
PetscInt label_value,
3822-
PetscBool ignore_label_halo):
3822+
PetscBool ignore_label_halo,
3823+
comm=None):
38233824
"""Create submesh.
38243825
38253826
Parameters
@@ -3834,12 +3835,12 @@ def submesh_create(PETSc.DM dm,
38343835
Value in the label
38353836
ignore_label_halo : bool
38363837
If labeled points in the halo are ignored.
3838+
comm : PETSc.Comm | None
3839+
An optional sub-communicator to define the submesh.
38373840
38383841
"""
38393842
cdef:
3840-
PETSc.DM subdm = PETSc.DMPlex()
38413843
PETSc.DMLabel label, temp_label
3842-
PETSc.SF ownership_transfer_sf = PETSc.SF()
38433844
char *temp_label_name = <char *>"firedrake_submesh_temp_label"
38443845
PetscInt pStart, pEnd, p, i, stratum_size
38453846
PETSc.PetscIS stratum_is = NULL
@@ -3863,7 +3864,11 @@ def submesh_create(PETSc.DM dm,
38633864
CHKERR(ISRestoreIndices(stratum_is, &stratum_indices))
38643865
CHKERR(ISDestroy(&stratum_is))
38653866
# Make submesh using temp_label.
3866-
CHKERR(DMPlexFilter(dm.dm, temp_label.dmlabel, label_value, ignore_label_halo, PETSC_TRUE, &ownership_transfer_sf.sf, &subdm.dm))
3867+
subdm, ownership_transfer_sf = dm.filter(label=temp_label,
3868+
value=label_value,
3869+
ignoreHalo=ignore_label_halo,
3870+
sanitizeSubMesh=PETSC_TRUE,
3871+
comm=comm)
38673872
# Destroy temp_label.
38683873
dm.removeLabel(temp_label_name)
38693874
subdm.removeLabel(temp_label_name)
@@ -3901,7 +3906,7 @@ def submesh_correct_entity_classes(PETSc.DM dm,
39013906
DMLabel lbl_core, lbl_owned, lbl_ghost
39023907
PetscBool has
39033908

3904-
if dm.comm.size == 1:
3909+
if subdm.comm.size == 1:
39053910
return
39063911
CHKERR(DMPlexGetChart(dm.dm, &pStart, &pEnd))
39073912
CHKERR(DMPlexGetChart(subdm.dm, &subpStart, &subpEnd))

firedrake/cython/petschdr.pxi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ cdef extern from "petscdmplex.h" nogil:
7676
PetscErrorCode DMPlexLabelComplete(PETSc.PetscDM, PETSc.PetscDMLabel)
7777
PetscErrorCode DMPlexDistributeOverlap(PETSc.PetscDM,PetscInt,PETSc.PetscSF*,PETSc.PetscDM*)
7878

79-
PetscErrorCode DMPlexFilter(PETSc.PetscDM,PETSc.PetscDMLabel,PetscInt,PetscBool,PetscBool,PETSc.PetscSF*,PETSc.PetscDM*)
8079
PetscErrorCode DMPlexGetSubpointIS(PETSc.PetscDM,PETSc.PetscIS*)
8180
PetscErrorCode DMPlexGetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel*)
8281
PetscErrorCode DMPlexSetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel)

firedrake/mesh.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3012,7 +3012,7 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
30123012

30133013

30143014
def make_mesh_from_mesh_topology(topology, name, tolerance=0.5):
3015-
"""Make mesh from tpology.
3015+
"""Make mesh from topology.
30163016
30173017
Parameters
30183018
----------
@@ -4734,7 +4734,7 @@ def SubDomainData(geometric_expr):
47344734
return op2.Subset(m.cell_set, indices)
47354735

47364736

4737-
def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
4737+
def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None, ignore_halo=False, reorder=True, comm=None):
47384738
"""Construct a submesh from a given mesh.
47394739
47404740
Parameters
@@ -4743,12 +4743,20 @@ def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
47434743
Parent mesh (`MeshGeometry`).
47444744
subdim : int
47454745
Topological dimension of the submesh.
4746-
subdomain_id : int
4746+
subdomain_id : int | None
47474747
Subdomain ID representing the submesh.
4748-
label_name : str
4748+
`None` defines the submesh owned by the sub-communicator.
4749+
label_name : str | None
47494750
Name of the label to search ``subdomain_id`` in.
4750-
name : str
4751+
name : str | None
47514752
Name of the submesh.
4753+
ignore_halo : bool
4754+
Whether to exclude the halo from the submesh.
4755+
reorder : bool
4756+
Whether to reorder the mesh entities.
4757+
comm : PETSc.Comm | None
4758+
An optional sub-communicator to define the submesh.
4759+
By default, the submesh is defined on `mesh.comm`.
47524760
47534761
Returns
47544762
-------
@@ -4817,15 +4825,27 @@ def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
48174825
label_name = dmcommon.CELL_SETS_LABEL
48184826
elif subdim == dim - 1:
48194827
label_name = dmcommon.FACE_SETS_LABEL
4828+
if subdomain_id is None:
4829+
# Filter the plex with PETSc's default label (cells owned by comm)
4830+
if label_name != dmcommon.CELL_SETS_LABEL:
4831+
raise ValueError("subdomain_id == None requires label_name == CELL_SETS_LABEL.")
4832+
subplex, sf = plex.filter(sanitizeSubMesh=True, ignoreHalo=ignore_halo, comm=comm)
4833+
dmcommon.submesh_update_facet_labels(plex, subplex)
4834+
dmcommon.submesh_correct_entity_classes(plex, subplex, sf)
4835+
else:
4836+
subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, ignore_halo, comm=comm)
4837+
4838+
comm = comm or mesh.comm
48204839
name = name or _generate_default_submesh_name(mesh.name)
4821-
subplex = dmcommon.submesh_create(plex, subdim, label_name, subdomain_id, False)
48224840
subplex.setName(_generate_default_mesh_topology_name(name))
48234841
if subplex.getDimension() != subdim:
48244842
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")
48254843
submesh = Mesh(
48264844
subplex,
48274845
submesh_parent=mesh,
48284846
name=name,
4847+
comm=comm,
4848+
reorder=reorder,
48294849
distribution_parameters={
48304850
"partition": False,
48314851
"overlap_type": (DistributedMeshOverlapType.NONE, 0),

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
# Ensure that the PETSc getting linked against is compatible
23+
# TODO RELEASE set to ">=3.25"
2324
petsctools.init(version_spec=">=3.23.0")
2425
import petsc4py
2526

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
import numpy as np
3+
from firedrake import *
4+
from firedrake.petsc import PETSc
5+
6+
7+
def assert_local_equality(A, B):
8+
i0, j0, v0 = A.getValuesCSR()
9+
i1, j1, v1 = B.getValuesCSR()
10+
j0 -= A.createVecs()[0].getOwnershipRange()[0]
11+
j1 -= B.createVecs()[0].getOwnershipRange()[0]
12+
assert np.array_equal(i0, i1)
13+
assert np.array_equal(j0, j1)
14+
assert np.allclose(v0, v1)
15+
16+
17+
@pytest.mark.parallel([1, 3])
18+
def test_create_submesh_comm_self():
19+
subdomain_id = None
20+
nx = 4
21+
mesh = UnitSquareMesh(nx, nx, quadrilateral=True, reorder=False)
22+
submesh = Submesh(mesh, mesh.topological_dimension, subdomain_id, ignore_halo=True, reorder=False, comm=COMM_SELF)
23+
assert submesh.submesh_parent is mesh
24+
assert submesh.comm.size == 1
25+
assert submesh.cell_set.size == mesh.cell_set.size
26+
assert np.allclose(mesh.coordinates.dat.data_ro, submesh.coordinates.dat.data_ro)
27+
28+
29+
@pytest.mark.parallel([1, 3])
30+
def test_assemble_submesh_comm_self():
31+
subdomain_id = None
32+
nx = 6
33+
ny = 5
34+
px = -np.cos(np.linspace(0, np.pi, nx))
35+
py = -np.cos(np.linspace(0, np.pi, ny))
36+
mesh = TensorRectangleMesh(px, py, reorder=False)
37+
submesh = Submesh(mesh, mesh.topological_dimension, subdomain_id, ignore_halo=True, reorder=False, comm=COMM_SELF)
38+
39+
Vsub = FunctionSpace(submesh, "DG", 0)
40+
Asub = assemble(inner(TrialFunction(Vsub), TestFunction(Vsub))*dx)
41+
42+
V = FunctionSpace(mesh, "DG", 0)
43+
A = assemble(inner(TrialFunction(V), TestFunction(V))*dx)
44+
assert_local_equality(A.petscmat, Asub.petscmat)
45+
46+
47+
@pytest.mark.parallel([1, 3])
48+
@pytest.mark.parametrize("label", ["some", "all"])
49+
def test_label_submesh_comm_self(label):
50+
subdomain_id = 999
51+
nx = 8
52+
mesh = UnitSquareMesh(nx, nx, reorder=False)
53+
54+
M = FunctionSpace(mesh, "DG", 0)
55+
marker = Function(M)
56+
if label == "some":
57+
x, y = SpatialCoordinate(mesh)
58+
marker.interpolate(conditional(Or(x > 0.5, y > 0.5), 1, 0))
59+
elif label == "all":
60+
marker.assign(1)
61+
else:
62+
raise ValueError(f"Unrecognized label {label}")
63+
64+
mesh = RelabeledMesh(mesh, [marker], [subdomain_id])
65+
submesh = Submesh(mesh, mesh.topological_dimension, subdomain_id, ignore_halo=True, reorder=False, comm=COMM_SELF)
66+
67+
Vsub = FunctionSpace(submesh, "DG", 0)
68+
Asub = assemble(inner(TrialFunction(Vsub), TestFunction(Vsub)) * dx)
69+
70+
V = FunctionSpace(mesh, "DG", 0)
71+
A = assemble(inner(TrialFunction(V), TestFunction(V)) * dx)
72+
if label == "all":
73+
assert_local_equality(A.petscmat, Asub.petscmat)
74+
else:
75+
lgmap = V.dof_dset.lgmap
76+
indices = PETSc.IS().createGeneral(lgmap.apply(np.flatnonzero(marker.dat.data).astype(PETSc.IntType)))
77+
Amat = A.petscmat.createSubMatrix(indices, indices)
78+
assert_local_equality(Amat, Asub.petscmat)

0 commit comments

Comments
 (0)