Skip to content

Commit 73c7394

Browse files
committed
add type annotations for adaptive/learner/triangulation.py
1 parent 341d313 commit 73c7394

File tree

1 file changed

+90
-35
lines changed

1 file changed

+90
-35
lines changed

adaptive/learner/triangulation.py

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from collections.abc import Iterable, Sized
44
from itertools import chain, combinations
55
from math import factorial
6+
from typing import Any, Iterator, List, Optional, Set, Tuple, Union
67

78
import numpy as np
89
import scipy.spatial
10+
from numpy import bool_, float64, int32, ndarray
911

1012

11-
def fast_norm(v):
13+
def fast_norm(v: Union[Tuple[float64, float64, float64], ndarray]) -> float:
1214
# notice this method can be even more optimised
1315
if len(v) == 2:
1416
return math.sqrt(v[0] * v[0] + v[1] * v[1])
@@ -17,7 +19,15 @@ def fast_norm(v):
1719
return math.sqrt(np.dot(v, v))
1820

1921

20-
def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
22+
def fast_2d_point_in_simplex(
23+
point: Union[Tuple[int, int], Tuple[float, float], Tuple[float64, float64]],
24+
simplex: Union[
25+
List[Union[Tuple[int, int], Tuple[float64, float64]]],
26+
List[Tuple[float64, float64]],
27+
ndarray,
28+
],
29+
eps: float = 1e-8,
30+
) -> Union[bool, bool_]:
2131
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
2232
px, py = point
2333

@@ -31,7 +41,7 @@ def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
3141
return (t >= -eps) and (s + t <= 1 + eps)
3242

3343

34-
def point_in_simplex(point, simplex, eps=1e-8):
44+
def point_in_simplex(point: Any, simplex: Any, eps: float = 1e-8) -> Union[bool, bool_]:
3545
if len(point) == 2:
3646
return fast_2d_point_in_simplex(point, simplex, eps)
3747

@@ -42,7 +52,7 @@ def point_in_simplex(point, simplex, eps=1e-8):
4252
return all(alpha > -eps) and sum(alpha) < 1 + eps
4353

4454

45-
def fast_2d_circumcircle(points):
55+
def fast_2d_circumcircle(points: ndarray) -> Tuple[Tuple[float64, float64], float]:
4656
"""Compute the center and radius of the circumscribed circle of a triangle
4757
4858
Parameters
@@ -78,7 +88,9 @@ def fast_2d_circumcircle(points):
7888
return (x + points[0][0], y + points[0][1]), radius
7989

8090

81-
def fast_3d_circumcircle(points):
91+
def fast_3d_circumcircle(
92+
points: ndarray,
93+
) -> Tuple[Tuple[float64, float64, float64], float]:
8294
"""Compute the center and radius of the circumscribed shpere of a simplex.
8395
8496
Parameters
@@ -118,7 +130,7 @@ def fast_3d_circumcircle(points):
118130
return center, radius
119131

120132

121-
def fast_det(matrix):
133+
def fast_det(matrix: ndarray) -> float64:
122134
matrix = np.asarray(matrix, dtype=float)
123135
if matrix.shape == (2, 2):
124136
return matrix[0][0] * matrix[1][1] - matrix[1][0] * matrix[0][1]
@@ -129,7 +141,13 @@ def fast_det(matrix):
129141
return np.linalg.det(matrix)
130142

131143

132-
def circumsphere(pts):
144+
def circumsphere(
145+
pts: ndarray,
146+
) -> Union[
147+
Tuple[Tuple[float64, float64, float64, float64], float],
148+
Tuple[Tuple[float64, float64], float],
149+
Tuple[Tuple[float64, float64, float64], float],
150+
]:
133151
dim = len(pts) - 1
134152
if dim == 2:
135153
return fast_2d_circumcircle(pts)
@@ -155,7 +173,7 @@ def circumsphere(pts):
155173
return tuple(center), radius
156174

157175

158-
def orientation(face, origin):
176+
def orientation(face: Any, origin: Any) -> Union[int, float64]:
159177
"""Compute the orientation of the face with respect to a point, origin.
160178
161179
Parameters
@@ -181,11 +199,18 @@ def orientation(face, origin):
181199
return sign
182200

183201

184-
def is_iterable_and_sized(obj):
202+
def is_iterable_and_sized(obj: Any) -> bool:
185203
return isinstance(obj, Iterable) and isinstance(obj, Sized)
186204

187205

188-
def simplex_volume_in_embedding(vertices) -> float:
206+
def simplex_volume_in_embedding(
207+
vertices: Union[
208+
List[Tuple[float64, float64, float64]],
209+
List[Tuple[float, float64, float64]],
210+
List[Tuple[float64, float64, float64, float64]],
211+
List[Tuple[float64, float64, float64, float64, float64]],
212+
]
213+
) -> float:
189214
"""Calculate the volume of a simplex in a higher dimensional embedding.
190215
That is: dim > len(vertices) - 1. For example if you would like to know the
191216
surface area of a triangle in a 3d space.
@@ -266,7 +291,7 @@ class Triangulation:
266291
or more simplices in the
267292
"""
268293

269-
def __init__(self, coords):
294+
def __init__(self, coords: Any) -> None:
270295
if not is_iterable_and_sized(coords):
271296
raise TypeError("Please provide a 2-dimensional list of points")
272297
coords = list(coords)
@@ -305,27 +330,27 @@ def __init__(self, coords):
305330
for simplex in initial_tri.simplices:
306331
self.add_simplex(simplex)
307332

308-
def delete_simplex(self, simplex):
333+
def delete_simplex(self, simplex: Any) -> None:
309334
simplex = tuple(sorted(simplex))
310335
self.simplices.remove(simplex)
311336
for vertex in simplex:
312337
self.vertex_to_simplices[vertex].remove(simplex)
313338

314-
def add_simplex(self, simplex):
339+
def add_simplex(self, simplex: Any) -> None:
315340
simplex = tuple(sorted(simplex))
316341
self.simplices.add(simplex)
317342
for vertex in simplex:
318343
self.vertex_to_simplices[vertex].add(simplex)
319344

320-
def get_vertices(self, indices):
345+
def get_vertices(self, indices: Any) -> Any:
321346
return [self.get_vertex(i) for i in indices]
322347

323-
def get_vertex(self, index):
348+
def get_vertex(self, index: Optional[Union[int32, int]]) -> Any:
324349
if index is None:
325350
return None
326351
return self.vertices[index]
327352

328-
def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
353+
def get_reduced_simplex(self, point: Any, simplex: Any, eps: float = 1e-8) -> list:
329354
"""Check whether vertex lies within a simplex.
330355
331356
Returns
@@ -350,11 +375,13 @@ def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
350375

351376
return [simplex[i] for i in result]
352377

353-
def point_in_simplex(self, point, simplex, eps=1e-8):
378+
def point_in_simplex(
379+
self, point: Any, simplex: Any, eps: float = 1e-8
380+
) -> Union[bool, bool_]:
354381
vertices = self.get_vertices(simplex)
355382
return point_in_simplex(point, vertices, eps)
356383

357-
def locate_point(self, point):
384+
def locate_point(self, point: Any) -> Any:
358385
"""Find to which simplex the point belongs.
359386
360387
Return indices of the simplex containing the point.
@@ -366,10 +393,12 @@ def locate_point(self, point):
366393
return ()
367394

368395
@property
369-
def dim(self):
396+
def dim(self) -> int:
370397
return len(self.vertices[0])
371398

372-
def faces(self, dim=None, simplices=None, vertices=None):
399+
def faces(
400+
self, dim: None = None, simplices: Optional[Any] = None, vertices: None = None
401+
) -> Iterator[Any]:
373402
"""Iterator over faces of a simplex or vertex sequence."""
374403
if dim is None:
375404
dim = self.dim
@@ -394,7 +423,7 @@ def containing(self, face):
394423
"""Simplices containing a face."""
395424
return set.intersection(*(self.vertex_to_simplices[i] for i in face))
396425

397-
def _extend_hull(self, new_vertex, eps=1e-8):
426+
def _extend_hull(self, new_vertex: Any, eps: float = 1e-8) -> Any:
398427
# count multiplicities in order to get all hull faces
399428
multiplicities = Counter(face for face in self.faces())
400429
hull_faces = [face for face, count in multiplicities.items() if count == 1]
@@ -434,7 +463,13 @@ def _extend_hull(self, new_vertex, eps=1e-8):
434463

435464
return new_simplices
436465

437-
def circumscribed_circle(self, simplex, transform):
466+
def circumscribed_circle(
467+
self, simplex: Any, transform: ndarray
468+
) -> Union[
469+
Tuple[Tuple[float64, float64, float64, float64], float],
470+
Tuple[Tuple[float64, float64], float],
471+
Tuple[Tuple[float64, float64, float64], float],
472+
]:
438473
"""Compute the center and radius of the circumscribed circle of a simplex.
439474
440475
Parameters
@@ -450,7 +485,9 @@ def circumscribed_circle(self, simplex, transform):
450485
pts = np.dot(self.get_vertices(simplex), transform)
451486
return circumsphere(pts)
452487

453-
def point_in_cicumcircle(self, pt_index, simplex, transform):
488+
def point_in_cicumcircle(
489+
self, pt_index: int, simplex: Any, transform: ndarray
490+
) -> bool_:
454491
# return self.fast_point_in_circumcircle(pt_index, simplex, transform)
455492
eps = 1e-8
456493

@@ -460,10 +497,15 @@ def point_in_cicumcircle(self, pt_index, simplex, transform):
460497
return np.linalg.norm(center - pt) < (radius * (1 + eps))
461498

462499
@property
463-
def default_transform(self):
500+
def default_transform(self) -> ndarray:
464501
return np.eye(self.dim)
465502

466-
def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
503+
def bowyer_watson(
504+
self,
505+
pt_index: int,
506+
containing_simplex: Optional[Any] = None,
507+
transform: Optional[ndarray] = None,
508+
) -> Any:
467509
"""Modified Bowyer-Watson point adding algorithm.
468510
469511
Create a hole in the triangulation around the new point,
@@ -523,10 +565,10 @@ def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
523565
new_triangles = self.vertex_to_simplices[pt_index]
524566
return bad_triangles - new_triangles, new_triangles - bad_triangles
525567

526-
def _simplex_is_almost_flat(self, simplex):
568+
def _simplex_is_almost_flat(self, simplex: Any) -> bool_:
527569
return self._relative_volume(simplex) < 1e-8
528570

529-
def _relative_volume(self, simplex):
571+
def _relative_volume(self, simplex: Any) -> float64:
530572
"""Compute the volume of a simplex divided by the average (Manhattan)
531573
distance of its vertices. The advantage of this is that the relative
532574
volume is only dependent on the shape of the simplex and not on the
@@ -537,7 +579,12 @@ def _relative_volume(self, simplex):
537579
average_edge_length = np.mean(np.abs(vectors))
538580
return self.volume(simplex) / (average_edge_length ** self.dim)
539581

540-
def add_point(self, point, simplex=None, transform=None):
582+
def add_point(
583+
self,
584+
point: Any,
585+
simplex: Optional[Any] = None,
586+
transform: Optional[ndarray] = None,
587+
) -> Any:
541588
"""Add a new vertex and create simplices as appropriate.
542589
543590
Parameters
@@ -586,13 +633,13 @@ def add_point(self, point, simplex=None, transform=None):
586633
self.vertices.append(point)
587634
return self.bowyer_watson(pt_index, actual_simplex, transform)
588635

589-
def volume(self, simplex):
636+
def volume(self, simplex: Any) -> float:
590637
prefactor = np.math.factorial(self.dim)
591638
vertices = np.array(self.get_vertices(simplex))
592639
vectors = vertices[1:] - vertices[0]
593640
return float(abs(fast_det(vectors)) / prefactor)
594641

595-
def volumes(self):
642+
def volumes(self) -> List[float]:
596643
return [self.volume(sim) for sim in self.simplices]
597644

598645
def reference_invariant(self):
@@ -609,21 +656,29 @@ def vertex_invariant(self, vertex):
609656
"""Simplices originating from a vertex don't overlap."""
610657
raise NotImplementedError
611658

612-
def get_neighbors_from_vertices(self, simplex):
659+
def get_neighbors_from_vertices(self, simplex: Any) -> Any:
613660
return set.union(*[self.vertex_to_simplices[p] for p in simplex])
614661

615-
def get_face_sharing_neighbors(self, neighbors, simplex):
662+
def get_face_sharing_neighbors(self, neighbors: Any, simplex: Any) -> Any:
616663
"""Keep only the simplices sharing a whole face with simplex."""
617664
return {
618665
simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim
619666
} # they share a face
620667

621-
def get_simplices_attached_to_points(self, indices):
668+
def get_simplices_attached_to_points(self, indices: Any) -> Any:
622669
# Get all simplices that share at least a point with the simplex
623670
neighbors = self.get_neighbors_from_vertices(indices)
624671
return self.get_face_sharing_neighbors(neighbors, indices)
625672

626-
def get_opposing_vertices(self, simplex):
673+
def get_opposing_vertices(
674+
self,
675+
simplex: Union[
676+
Tuple[int32, int, int],
677+
Tuple[int32, int32, int],
678+
Tuple[int32, int32, int32],
679+
Tuple[int, int, int],
680+
],
681+
) -> Any:
627682
if simplex not in self.simplices:
628683
raise ValueError("Provided simplex is not part of the triangulation")
629684
neighbors = self.get_simplices_attached_to_points(simplex)
@@ -641,7 +696,7 @@ def find_opposing_vertex(vertex):
641696
return result
642697

643698
@property
644-
def hull(self):
699+
def hull(self) -> Union[Set[int32], Set[int], Set[Union[int32, int]]]:
645700
"""Compute hull from triangulation.
646701
647702
Parameters

0 commit comments

Comments
 (0)