Skip to content

Commit d145b59

Browse files
committed
add type hints to mesh partitioning
1 parent 74425fe commit d145b59

File tree

1 file changed

+56
-37
lines changed

1 file changed

+56
-37
lines changed

meshmode/mesh/processing.py

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from functools import reduce
2424
from numbers import Real
25-
from typing import Optional, Union
25+
from typing import Optional, Union, Any, Tuple, Dict, List, Set
2626

2727
from dataclasses import dataclass
2828

@@ -32,7 +32,10 @@
3232
import modepy as mp
3333

3434
from meshmode.mesh import (
35+
MeshElementGroup,
36+
Mesh,
3537
BTAG_PARTITION,
38+
PartID,
3639
InteriorAdjacencyGroup,
3740
BoundaryAdjacencyGroup,
3841
InterPartAdjacencyGroup
@@ -81,14 +84,17 @@ def find_group_indices(groups, meshwide_elems):
8184
# {{{ partition_mesh
8285

8386
def _compute_global_elem_to_part_elem(
84-
nelements, part_id_to_elements, part_id_to_part_index, element_id_dtype):
87+
nelements: int,
88+
part_id_to_elements: Dict[PartID, np.ndarray],
89+
part_id_to_part_index: Dict[PartID, int],
90+
element_id_dtype: Any) -> np.ndarray:
8591
"""
8692
Create a map from global element index to part-wide element index for a set of
8793
parts.
8894
8995
:arg nelements: The number of elements in the global mesh.
90-
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to
91-
sets of elements.
96+
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
97+
a sorted :class:`numpy.ndarray` of elements.
9298
:arg part_id_to_part_index: A mapping from part identifiers to indices in
9399
the range ``[0, num_parts)``.
94100
:arg element_id_dtype: The element index data type.
@@ -107,7 +113,10 @@ def _compute_global_elem_to_part_elem(
107113
return global_elem_to_part_elem
108114

109115

110-
def _filter_mesh_groups(mesh, selected_elements, vertex_id_dtype):
116+
def _filter_mesh_groups(
117+
mesh: Mesh,
118+
selected_elements: np.ndarray,
119+
vertex_id_dtype: Any) -> Tuple[List, np.ndarray]:
111120
"""
112121
Create new mesh groups containing a selected subset of elements.
113122
@@ -173,7 +182,10 @@ def _filter_mesh_groups(mesh, selected_elements, vertex_id_dtype):
173182

174183

175184
def _get_connected_parts(
176-
mesh, part_id_to_part_index, global_elem_to_part_elem, self_part_id):
185+
mesh: Mesh,
186+
part_id_to_part_index: Dict[PartID, int],
187+
global_elem_to_part_elem: np.ndarray,
188+
self_part_id: PartID) -> "Set[PartID]":
177189
"""
178190
Find the parts that are connected to the current part.
179191
@@ -218,8 +230,12 @@ def _get_connected_parts(
218230
if part_index in connected_part_indices}
219231

220232

221-
def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,
222-
self_part_index, self_mesh_groups, self_mesh_group_elem_base):
233+
def _create_self_to_self_adjacency_groups(
234+
mesh: Mesh,
235+
global_elem_to_part_elem: np.ndarray,
236+
self_part_index: int,
237+
self_mesh_groups: List[MeshElementGroup],
238+
self_mesh_group_elem_base: List[int]) -> List[List[InteriorAdjacencyGroup]]:
223239
r"""
224240
Create self-to-self facial adjacency groups for a partitioned mesh.
225241
@@ -229,9 +245,9 @@ def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,
229245
:func:`_compute_global_elem_to_part_elem`` for details.
230246
:arg self_part_index: The index of the part currently being created, in the
231247
range ``[0, num_parts)``.
232-
:arg self_mesh_groups: An array of :class:`~meshmode.mesh.MeshElementGroup`
248+
:arg self_mesh_groups: A list of :class:`~meshmode.mesh.MeshElementGroup`
233249
instances representing the partitioned mesh groups.
234-
:arg self_mesh_group_elem_base: An array containing the starting part-wide
250+
:arg self_mesh_group_elem_base: A list containing the starting part-wide
235251
element index for each group in *self_mesh_groups*.
236252
237253
:returns: A list of lists of `~meshmode.mesh.InteriorAdjacencyGroup` instances
@@ -283,8 +299,13 @@ def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,
283299

284300

285301
def _create_self_to_other_adjacency_groups(
286-
mesh, part_id_to_part_index, global_elem_to_part_elem, self_part_id,
287-
self_mesh_groups, self_mesh_group_elem_base, connected_parts):
302+
mesh: Mesh,
303+
part_id_to_part_index: Dict[PartID, int],
304+
global_elem_to_part_elem: np.ndarray,
305+
self_part_id: PartID,
306+
self_mesh_groups: List[MeshElementGroup],
307+
self_mesh_group_elem_base: List[int],
308+
connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
288309
"""
289310
Create self-to-other adjacency groups for the partitioned mesh.
290311
@@ -295,9 +316,9 @@ def _create_self_to_other_adjacency_groups(
295316
indices to part indices and part-wide element indices. See
296317
:func:`_compute_global_elem_to_part_elem`` for details.
297318
:arg self_part_id: The identifier of the part currently being created.
298-
:arg self_mesh_groups: An array of `~meshmode.mesh.MeshElementGroup` instances
319+
:arg self_mesh_groups: A list of `~meshmode.mesh.MeshElementGroup` instances
299320
representing the partitioned mesh groups.
300-
:arg self_mesh_group_elem_base: An array containing the starting part-wide
321+
:arg self_mesh_group_elem_base: A list containing the starting part-wide
301322
element index for each group in *self_mesh_groups*.
302323
:arg connected_parts: A :class:`set` containing the parts connected to
303324
the current one.
@@ -358,8 +379,12 @@ def _create_self_to_other_adjacency_groups(
358379
return self_to_other_adj_groups
359380

360381

361-
def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
362-
self_mesh_groups, self_mesh_group_elem_base):
382+
def _create_boundary_groups(
383+
mesh: Mesh,
384+
global_elem_to_part_elem: np.ndarray,
385+
self_part_index: PartID,
386+
self_mesh_groups: List[MeshElementGroup],
387+
self_mesh_group_elem_base: List[int]) -> List[List[BoundaryAdjacencyGroup]]:
363388
"""
364389
Create boundary groups for partitioned mesh.
365390
@@ -369,9 +394,9 @@ def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
369394
:func:`_compute_global_elem_to_part_elem`` for details.
370395
:arg self_part_index: The index of the part currently being created, in the
371396
range ``[0, num_parts)``.
372-
:arg self_mesh_groups: An array of `~meshmode.mesh.MeshElementGroup` instances
397+
:arg self_mesh_groups: A list of `~meshmode.mesh.MeshElementGroup` instances
373398
representing the partitioned mesh groups.
374-
:arg self_mesh_group_elem_base: An array containing the starting part-wide
399+
:arg self_mesh_group_elem_base: A list containing the starting part-wide
375400
element index for each group in *self_mesh_groups*.
376401
377402
:returns: A list of lists of `~meshmode.mesh.BoundaryAdjacencyGroup` instances
@@ -411,11 +436,14 @@ def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
411436
return bdry_adj_groups
412437

413438

414-
def _get_mesh_part(mesh, part_id_to_elements, self_part_id):
439+
def _get_mesh_part(
440+
mesh: Mesh,
441+
part_id_to_elements: Dict[PartID, np.ndarray],
442+
self_part_id: PartID) -> Mesh:
415443
"""
416444
:arg mesh: A :class:`~meshmode.mesh.Mesh` to be partitioned.
417-
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to
418-
sets of elements.
445+
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
446+
a sorted :class:`numpy.ndarray` of elements.
419447
:arg self_part_id: The part identifier of the mesh to return.
420448
421449
:returns: A :class:`~meshmode.mesh.Mesh` containing a part of *mesh*.
@@ -432,9 +460,7 @@ def _get_mesh_part(mesh, part_id_to_elements, self_part_id):
432460

433461
part_id_to_part_index = {
434462
part_id: part_index
435-
for part_id, part_index in zip(
436-
part_id_to_elements.keys(),
437-
range(len(part_id_to_elements)))}
463+
for part_index, part_id in enumerate(part_id_to_elements.keys())}
438464

439465
global_elem_to_part_elem = _compute_global_elem_to_part_elem(
440466
mesh.nelements, part_id_to_elements, part_id_to_part_index,
@@ -480,21 +506,20 @@ def _get_mesh_part(mesh, part_id_to_elements, self_part_id):
480506
+ boundary_adj_groups[igrp]
481507
for igrp in range(len(self_mesh_groups))]
482508

483-
from meshmode.mesh import Mesh
484-
self_mesh = Mesh(
509+
return Mesh(
485510
self_vertices,
486511
self_mesh_groups,
487512
facial_adjacency_groups=self_facial_adj_groups,
488513
is_conforming=mesh.is_conforming)
489514

490-
return self_mesh
491-
492515

493-
def partition_mesh(mesh, part_id_to_elements):
516+
def partition_mesh(
517+
mesh: Mesh,
518+
part_id_to_elements: Dict[PartID, np.ndarray]) -> "Dict[PartID, Mesh]":
494519
"""
495520
:arg mesh: A :class:`~meshmode.mesh.Mesh` to be partitioned.
496-
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to sets of
497-
elements.
521+
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
522+
a sorted :class:`numpy.ndarray` of elements.
498523
499524
:returns: A :class:`dict` mapping part identifiers to instances of
500525
:class:`~meshmode.mesh.Mesh` that represent the corresponding part of
@@ -696,8 +721,6 @@ def perform_flips(mesh, flip_flags, skip_tests=False):
696721

697722
flip_flags = flip_flags.astype(bool)
698723

699-
from meshmode.mesh import Mesh
700-
701724
new_groups = []
702725
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
703726
grp_flip_flags = flip_flags[base_element_nr:base_element_nr + grp.nelements]
@@ -835,7 +858,6 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):
835858

836859
# }}}
837860

838-
from meshmode.mesh import Mesh
839861
return Mesh(vertices, new_groups, skip_tests=skip_tests,
840862
nodal_adjacency=nodal_adjacency,
841863
facial_adjacency_groups=facial_adjacency_groups,
@@ -893,7 +915,6 @@ def split_mesh_groups(mesh, element_flags, return_subgroup_mapping=False):
893915
element_nr_base=None, node_nr_base=None,
894916
))
895917

896-
from meshmode.mesh import Mesh
897918
mesh = Mesh(
898919
vertices=mesh.vertices,
899920
groups=new_groups,
@@ -1183,8 +1204,6 @@ def glue_mesh_boundaries(mesh, bdry_pair_mappings_and_tols, *, use_tree=None):
11831204
_match_boundary_faces(mesh, mapping, tol, use_tree=use_tree)
11841205
for mapping, tol in bdry_pair_mappings_and_tols]
11851206

1186-
from meshmode.mesh import InteriorAdjacencyGroup, BoundaryAdjacencyGroup
1187-
11881207
facial_adjacency_groups = []
11891208

11901209
for igrp, old_fagrp_list in enumerate(mesh.facial_adjacency_groups):

0 commit comments

Comments
 (0)