Skip to content

Commit 6fa0338

Browse files
committed
Fixes
* Do not try and distribute when serial * Use a slightly different approach to building a topology_dm (to preserve labels)
1 parent d905c57 commit 6fa0338

File tree

4 files changed

+37
-25
lines changed

4 files changed

+37
-25
lines changed

firedrake/cython/dmcommon.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3972,6 +3972,7 @@ def dmplex_migrate(PETSc.DM dm, PETSc.SF sf) -> PETSc.DM:
39723972
return migrated_dm
39733973

39743974

3975+
# not used currently
39753976
def densify_sf(PETSc.DM topology_dm, PETSc.SF sparse_sf) -> PETSc.SF:
39763977
cdef:
39773978
PETSc.SF dense_sf
@@ -4008,10 +4009,16 @@ def densify_sf(PETSc.DM topology_dm, PETSc.SF sparse_sf) -> PETSc.SF:
40084009
return dense_sf
40094010

40104011

4012+
# not used currently
40114013
def dmplex_create_overlap_migration_sf(PETSc.DM topology_dm, PETSc.SF overlap_sf) -> PETSc.SF:
40124014
cdef:
40134015
PETSc.SF migration_sf
40144016

40154017
migration_sf = PETSc.SF().create(comm=topology_dm.comm)
40164018
CHKERR(DMPlexCreateOverlapMigrationSF(topology_dm.dm, overlap_sf.sf, &migration_sf.sf))
40174019
return migration_sf
4020+
4021+
4022+
# not used currently
4023+
def set_label(PETSc.DM dm, PETSc.DMLabel label) -> None:
4024+
CHKERR(DMSetLabel(dm.dm, label.dmlabel))

firedrake/cython/petschdr.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ cdef extern from "petscdmlabel.h" nogil:
7777
cdef extern from "petscdm.h" nogil:
7878
int DMCreateLabel(PETSc.PetscDM,char[])
7979
int DMGetLabel(PETSc.PetscDM,char[],DMLabel*)
80+
int DMSetLabel(PETSc.PetscDM,PETSc.PetscDMLabel)
8081
int DMGetPointSF(PETSc.PetscDM,PETSc.PetscSF*)
8182
int DMSetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt)
8283
int DMGetLabelValue(PETSc.PetscDM,char[],PetscInt,PetscInt*)

firedrake/mesh.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,16 +2318,11 @@ def callback(self):
23182318
del self._callback
23192319
# Finish the initialisation of mesh topology
23202320
self.topology.init()
2321-
# geometry_dm = self._input_geometry_dm.migrate(self.topology.overlap_sf)
2322-
2323-
# make the overlap SF dense
2324-
# dense_overlap_sf = dmcommon.densify_sf(self.topology.topology_dm, self.topology.overlap_sf)
2325-
# breakpoint()
2326-
# migration_sf = dmcommon.dmplex_create_overlap_migration_sf(self.topology.topology_dm, self.topology.overlap_sf)
2327-
# geometry_dm = dmcommon.dmplex_migrate(self._input_geometry_dm, dense_overlap_sf)
2328-
# geometry_dm = dmcommon.dmplex_migrate(self._input_geometry_dm, self.topology.overlap_sf)
2329-
# geometry_dm = dmcommon.dmplex_migrate(self._input_geometry_dm, migration_sf)
2330-
geometry_dm = dmcommon.dmplex_migrate(self._input_geometry_dm, self.topology.sfBC)
2321+
if self.comm.size > 1:
2322+
assert self.topology.sfBC is not None
2323+
geometry_dm = dmcommon.dmplex_migrate(self._input_geometry_dm, self.topology.sfBC)
2324+
else:
2325+
geometry_dm = self._input_geometry_dm
23312326

23322327
coordinates_fs = functionspace.FunctionSpace(self.topology, self.ufl_coordinate_element())
23332328

@@ -3128,22 +3123,32 @@ def Mesh(meshfile, **kwargs):
31283123
% (meshfile, ext[1:]))
31293124
plex.setName(_generate_default_mesh_topology_name(name))
31303125

3126+
# TODO: Push other labels onto the coordinate DM, this should be done
3127+
# by DMPlex
3128+
# NOTE: OR we instead set the plex as the coordinate DM by dropping
3129+
# coordinates
3130+
# topology_dm = plex.getCoordinateDM()
3131+
#
3132+
# face_sets_label = plex.getLabel(dmcommon.FACE_SETS_LABEL)
3133+
# topology_dm.createLabel(dmcommon.FACE_SETS_LABEL)
3134+
# dmcommon.set_label(topology_dm, face_sets_label)
3135+
topology_dm = plex.clone()
3136+
topology_dm.setCoordinates(plex.getCoordinateDM().getCoordinates())
3137+
31313138
# Create mesh topology
31323139
# Pass the coordinate DM because we only want the topology here
3133-
topology = MeshTopology(plex.getCoordinateDM(), name=plex.getName(), reorder=reorder,
3140+
# topology = MeshTopology(plex.getCoordinateDM(), name=plex.getName(), reorder=reorder,
3141+
topology = MeshTopology(topology_dm, name=plex.getName(), reorder=reorder,
31343142
distribution_parameters=distribution_parameters,
31353143
distribution_name=kwargs.get("distribution_name"),
31363144
permutation_name=kwargs.get("permutation_name"),
31373145
comm=user_comm)
31383146

3139-
# distributed_plex = dmcommon.dmplex_migrate(plex, topology.sfBC)
3140-
distributed_plex = plex
3141-
31423147
if netgen and isinstance(meshfile, netgen.libngpy._meshing.Mesh):
31433148
netgen_firedrake_mesh.createFromTopology(topology, name=name, comm=user_comm)
31443149
mesh = netgen_firedrake_mesh.firedrakeMesh
31453150
else:
3146-
mesh = make_mesh_from_mesh_topology(topology, distributed_plex, name)
3151+
mesh = make_mesh_from_mesh_topology(topology, plex, name)
31473152
mesh._tolerance = tolerance
31483153
return mesh
31493154

firedrake/utility_meshes.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -736,13 +736,12 @@ def TensorRectangleMesh(
736736
)
737737

738738
# mark boundary facets
739-
topology_dm = plex.getCoordinateDM()
740-
topology_dm.createLabel(dmcommon.FACE_SETS_LABEL)
741-
topology_dm.markBoundaryFaces("boundary_faces")
739+
plex.createLabel(dmcommon.FACE_SETS_LABEL)
740+
plex.markBoundaryFaces("boundary_faces")
742741
coords = plex.getCoordinates()
743742
coord_sec = plex.getCoordinateSection()
744-
if topology_dm.getStratumSize("boundary_faces", 1) > 0:
745-
boundary_faces = topology_dm.getStratumIS("boundary_faces", 1).getIndices()
743+
if plex.getStratumSize("boundary_faces", 1) > 0:
744+
boundary_faces = plex.getStratumIS("boundary_faces", 1).getIndices()
746745
xtol = 0.5 * min(xcoords[1] - xcoords[0], xcoords[-1] - xcoords[-2])
747746
ytol = 0.5 * min(ycoords[1] - ycoords[0], ycoords[-1] - ycoords[-2])
748747
x0 = xcoords[0]
@@ -752,14 +751,14 @@ def TensorRectangleMesh(
752751
for face in boundary_faces:
753752
face_coords = plex.vecGetClosure(coord_sec, coords, face)
754753
if abs(face_coords[0] - x0) < xtol and abs(face_coords[2] - x0) < xtol:
755-
topology_dm.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1)
754+
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1)
756755
if abs(face_coords[0] - x1) < xtol and abs(face_coords[2] - x1) < xtol:
757-
topology_dm.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2)
756+
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2)
758757
if abs(face_coords[1] - y0) < ytol and abs(face_coords[3] - y0) < ytol:
759-
topology_dm.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3)
758+
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3)
760759
if abs(face_coords[1] - y1) < ytol and abs(face_coords[3] - y1) < ytol:
761-
topology_dm.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4)
762-
topology_dm.removeLabel("boundary_faces")
760+
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4)
761+
plex.removeLabel("boundary_faces")
763762
m = mesh.Mesh(
764763
plex,
765764
reorder=reorder,

0 commit comments

Comments
 (0)