Skip to content

Commit 20197bc

Browse files
committed
dmswarm update
1 parent cf84491 commit 20197bc

File tree

4 files changed

+25
-25
lines changed

4 files changed

+25
-25
lines changed

firedrake/cython/dmcommon.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,6 +2120,8 @@ def mark_entity_classes_using_cell_dm(PETSc.DM swarm):
21202120
PetscInt nswarmCells, swarmCell, blocksize
21212121
PetscInt *swarmParentCells = NULL
21222122
PetscDataType ctype = PETSC_DATATYPE_UNKNOWN
2123+
const char *cellid = NULL
2124+
PETSc.PetscDMSwarmCellDM celldm
21232125

21242126
plex = swarm.getCellDM()
21252127
get_height_stratum(plex.dm, 0, &cStart, &cEnd)
@@ -2145,14 +2147,16 @@ def mark_entity_classes_using_cell_dm(PETSc.DM swarm):
21452147
for ilabel, op2class in enumerate([b"pyop2_core", b"pyop2_owned", b"pyop2_ghost"]):
21462148
CHKERR(DMCreateLabel(swarm.dm, op2class))
21472149
CHKERR(DMGetLabel(swarm.dm, op2class, &swarm_labels[ilabel]))
2148-
CHKERR(DMSwarmGetField(swarm.dm, b"DMSwarm_cellid", &blocksize, &ctype, <void**>&swarmParentCells))
2150+
CHKERR(DMSwarmGetCellDMActive(swarm.dm, &celldm))
2151+
CHKERR(DMSwarmCellDMGetCellID(celldm, &cellid))
2152+
CHKERR(DMSwarmGetField(swarm.dm, cellid, &blocksize, &ctype, <void**> &swarmParentCells))
21492153
assert ctype == PETSC_INT
21502154
assert blocksize == 1
21512155
CHKERR(DMSwarmGetLocalSize(swarm.dm, &nswarmCells))
21522156
for swarmCell in range(nswarmCells):
21532157
plex_cell_class = plex_cell_classes[swarmParentCells[swarmCell] - cStart]
21542158
CHKERR(DMLabelSetValue(swarm_labels[plex_cell_class], swarmCell, label_value))
2155-
CHKERR(DMSwarmRestoreField(swarm.dm, b"DMSwarm_cellid", &blocksize, &ctype, <void**>&swarmParentCells))
2159+
CHKERR(DMSwarmRestoreField(swarm.dm, cellid, &blocksize, &ctype, <void**> &swarmParentCells))
21562160
CHKERR(PetscFree(plex_cell_classes))
21572161

21582162

firedrake/cython/petschdr.pxi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ cdef extern from "petscdm.h" nogil:
7979
cdef extern from "petscdmswarm.h" nogil:
8080
int DMSwarmGetLocalSize(PETSc.PetscDM,PetscInt*)
8181
int DMSwarmGetCellDM(PETSc.PetscDM, PETSc.PetscDM*)
82+
int DMSwarmGetCellDMActive(PETSc.PetscDM, PETSc.PetscDMSwarmCellDM*)
83+
int DMSwarmCellDMGetCellID(PETSc.PetscDMSwarmCellDM, const char *[])
8284
int DMSwarmGetField(PETSc.PetscDM,const char[],PetscInt*,PetscDataType*,void**)
8385
int DMSwarmRestoreField(PETSc.PetscDM,const char[],PetscInt*,PetscDataType*,void**)
8486

firedrake/mesh.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,14 +1966,15 @@ def _renumber_entities(self, reorder):
19661966
if reorder:
19671967
swarm = self.topology_dm
19681968
parent = self._parent_mesh.topology_dm
1969-
swarm_parent_cell_nums = swarm.getField("DMSwarm_cellid").ravel()
1969+
cell_id_name = swarm.getCellDMActive().getCellID()
1970+
swarm_parent_cell_nums = swarm.getField(cell_id_name).ravel()
19701971
parent_renum = self._parent_mesh._dm_renumbering.getIndices()
19711972
pStart, _ = parent.getChart()
19721973
parent_renum_inv = np.empty_like(parent_renum)
19731974
parent_renum_inv[parent_renum - pStart] = np.arange(len(parent_renum))
19741975
# Use kind = 'stable' to make the ordering deterministic.
19751976
perm = np.argsort(parent_renum_inv[swarm_parent_cell_nums - pStart], kind='stable').astype(IntType)
1976-
swarm.restoreField("DMSwarm_cellid")
1977+
swarm.restoreField(cell_id_name)
19771978
perm_is = PETSc.IS().create(comm=swarm.comm)
19781979
perm_is.setType("general")
19791980
perm_is.setIndices(perm)
@@ -3557,11 +3558,9 @@ def _pic_swarm_in_mesh(
35573558
#. ``parentcellextrusionheight`` which contains the extrusion height of
35583559
the immersed vertex in the parent mesh cell.
35593560
3560-
Another three are required for proper functioning of the DMSwarm:
3561+
Another two are required for proper functioning of the DMSwarm:
35613562
35623563
#. ``DMSwarmPIC_coor`` which contains the coordinates of the point.
3563-
#. ``DMSwarm_cellid`` the DMPlex cell within which the DMSwarm point is
3564-
located.
35653564
#. ``DMSwarm_rank``: the MPI rank which owns the DMSwarm point.
35663565
35673566
.. note::
@@ -3794,7 +3793,6 @@ def _dmswarm_create(
37943793
# These are created by default for a PIC DMSwarm
37953794
default_fields = [
37963795
("DMSwarmPIC_coor", gdim, RealType),
3797-
("DMSwarm_cellid", 1, IntType),
37983796
("DMSwarm_rank", 1, IntType),
37993797
]
38003798

@@ -3853,12 +3851,6 @@ def _dmswarm_create(
38533851
# Set to Particle In Cell (PIC) type
38543852
if not isinstance(plex, PETSc.DMSwarm):
38553853
swarm.setType(PETSc.DMSwarm.Type.PIC)
3856-
else:
3857-
# This doesn't work where we embed a DMSwarm in a DMSwarm, instead
3858-
# we register some default fields manually
3859-
for name, size, dtype in default_fields:
3860-
if name == "DMSwarmPIC_coor" or name == "DMSwarm_cellid":
3861-
swarm.registerField(name, size, dtype=dtype)
38623854

38633855
# Register any fields
38643856
for name, size, dtype in swarm.default_extra_fields + swarm.other_fields:
@@ -3872,14 +3864,15 @@ def _dmswarm_create(
38723864
# Add point coordinates. This amounts to our own implementation of
38733865
# DMSwarmSetPointCoordinates because Firedrake's mesh coordinate model
38743866
# doesn't always exactly coincide with that of DMPlex: in most cases the
3875-
# plex_parent_cell_nums (DMSwarm_cellid field) and parent_cell_nums
3876-
# (parentcellnum field), the latter being the numbering used by firedrake,
3877-
# refer fundamentally to the same cells. For extruded meshes the DMPlex
3878-
# dimension is based on the topological dimension of the base mesh.
3867+
# plex_parent_cell_nums and parent_cell_nums (parentcellnum field), the
3868+
# latter being the numbering used by firedrake, refer fundamentally to the
3869+
# same cells. For extruded meshes the DMPlex dimension is based on the
3870+
# topological dimension of the base mesh.
38793871

38803872
# NOTE ensure that swarm.restoreField is called for each field too!
38813873
swarm_coords = swarm.getField("DMSwarmPIC_coor").reshape((num_vertices, gdim))
3882-
swarm_parent_cell_nums = swarm.getField("DMSwarm_cellid").ravel()
3874+
cell_id_name = swarm.getCellDMActive().getCellID()
3875+
swarm_parent_cell_nums = swarm.getField(cell_id_name).ravel()
38833876
field_parent_cell_nums = swarm.getField("parentcellnum").ravel()
38843877
field_reference_coords = swarm.getField("refcoord").reshape((num_vertices, tdim))
38853878
field_global_index = swarm.getField("globalindex").ravel()
@@ -3903,7 +3896,7 @@ def _dmswarm_create(
39033896
swarm.restoreField("refcoord")
39043897
swarm.restoreField("parentcellnum")
39053898
swarm.restoreField("DMSwarmPIC_coor")
3906-
swarm.restoreField("DMSwarm_cellid")
3899+
swarm.restoreField(cell_id_name)
39073900

39083901
if extruded:
39093902
field_base_parent_cell_nums = swarm.getField("parentcellbasenum").ravel()

tests/firedrake/vertexonly/test_swarm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,9 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
218218
nptslocal = len(localpointcoords)
219219
nptsglobal = MPI.COMM_WORLD.allreduce(nptslocal, op=MPI.SUM)
220220
# Get parent PETSc cell indices on current MPI rank
221-
localparentcellindices = np.copy(swarm.getField("DMSwarm_cellid").ravel())
222-
swarm.restoreField("DMSwarm_cellid")
221+
cell_id = swarm.getCellDMActive().getCellID()
222+
localparentcellindices = np.copy(swarm.getField(cell_id).ravel())
223+
swarm.restoreField(cell_id)
223224

224225
# also get the global coordinate numbering
225226
globalindices = np.copy(swarm.getField("globalindex").ravel())
@@ -242,7 +243,6 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
242243
# Check swarm fields are correct
243244
default_fields = [
244245
("DMSwarmPIC_coor", parentmesh.geometric_dimension(), RealType),
245-
("DMSwarm_cellid", 1, IntType),
246246
("DMSwarm_rank", 1, IntType),
247247
]
248248
default_extra_fields = [
@@ -378,8 +378,9 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos):
378378
):
379379
swarm.setPointCoordinates(localpointcoords, redundant=False,
380380
mode=PETSc.InsertMode.INSERT_VALUES)
381-
petsclocalparentcellindices = np.copy(swarm.getField("DMSwarm_cellid").ravel())
382-
swarm.restoreField("DMSwarm_cellid")
381+
cell_id = swarm.getCellDMActive().getCellID()
382+
petsclocalparentcellindices = np.copy(swarm.getField(cell_id).ravel())
383+
swarm.restoreField(cell_id)
383384
if exclude_halos:
384385
assert np.all(petsclocalparentcellindices == localparentcellindices)
385386
elif parentmesh.comm.size > 1:

0 commit comments

Comments
 (0)