Skip to content

Commit 8c4bbc9

Browse files
authored
Merge pull request #99 from MunchLab/fix-mypy-types
Add type hinting for `all_coords` parameter in validation functions and
2 parents e6040c2 + 2b892b9 commit 8c4bbc9

File tree

5 files changed

+107
-72
lines changed

5 files changed

+107
-72
lines changed

src/ect/dect.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .results import ECTResult
55
from typing import Optional, Union
66
import numpy as np
7-
from numba import njit
7+
from numba import njit # type: ignore[attr-defined]
88

99

1010
class DECT(ECT):
@@ -86,18 +86,19 @@ def _compute_directional_transform(
8686
def calculate(
8787
self,
8888
graph: Union[EmbeddedGraph, EmbeddedCW],
89-
scale: Optional[float] = None,
9089
theta: Optional[float] = None,
9190
override_bound_radius: Optional[float] = None,
91+
*,
92+
scale: Optional[float] = None,
9293
) -> ECTResult:
9394
"""
9495
Calculate the Differentiable Euler Characteristic Transform (DECT) for a given embedded complex.
9596
9697
Args:
9798
graph (EmbeddedGraph or EmbeddedCW): The embedded complex to analyze.
98-
scale (Optional[float]): Slope parameter for the sigmoid function. If None, uses the instance's scale.
9999
theta (Optional[float]): Specific direction angle to use. If None, uses all directions.
100100
override_bound_radius (Optional[float]): Override for bounding radius in threshold generation.
101+
scale (Optional[float]): Slope parameter for the sigmoid function. If None, uses the instance's scale.
101102
102103
Returns:
103104
ECTResult: Result object containing the DECT matrix, directions, and thresholds.

src/ect/ect.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from numba import prange, njit
2+
from numba import prange, njit # type: ignore[attr-defined]
33
from numba.typed import List
44
from typing import Optional
55

@@ -16,7 +16,7 @@ class ECT:
1616
The result is a matrix where entry ``M[i,j]`` is :math:`\chi(K_{a_i})` for the direction :math:`\omega_j`
1717
where :math:`a_i` is the ith entry in ``self.thresholds``, and :math:`\omega_j` is the jth entry in ``self.directions``.
1818
19-
19+
2020
2121
Example:
2222
>>> from ect import ECT, EmbeddedComplex
@@ -106,14 +106,15 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
106106
def calculate(
107107
self,
108108
graph: EmbeddedComplex,
109-
theta: float = None,
110-
override_bound_radius: float = None,
109+
theta: Optional[float] = None,
110+
override_bound_radius: Optional[float] = None,
111111
):
112112
self._ensure_directions(graph.dim, theta)
113113
self._ensure_thresholds(graph, override_bound_radius)
114114
directions = (
115115
self.directions if theta is None else Directions.from_angles([theta])
116116
)
117+
assert self.thresholds is not None
117118
ect_matrix = self._compute_ect(graph, directions, self.thresholds, self.dtype)
118119

119120
return ECTResult(ect_matrix, directions, self.thresholds)

src/ect/embed_complex.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,8 @@ def pca_projection(self, target_dim=2):
655655
pca = PCA(n_components=target_dim)
656656
self._coord_matrix = pca.fit_transform(self._coord_matrix)
657657

658-
658+
@staticmethod
659659
def validate_plot_parameters(func):
660-
# decorator to check plotting requirements
661660
@functools.wraps(func)
662661
def wrapper(self, *args, **kwargs):
663662
bounding_center_type = kwargs.get("bounding_center_type", "bounding_box")
@@ -706,7 +705,6 @@ def plot_faces(self, ax=None, **kwargs):
706705

707706
return ax
708707

709-
710708
@validate_plot_parameters
711709
def plot(
712710
self,
@@ -722,7 +720,7 @@ def plot(
722720
face_color: str = "lightblue",
723721
face_alpha: float = 0.3,
724722
**kwargs,
725-
) -> plt.Axes:
723+
) -> plt.Axes:
726724
"""
727725
Visualize the embedded complex in 2D or 3D
728726
@@ -739,7 +737,7 @@ def plot(
739737
face_color (str): Color for faces (2-cells)
740738
face_alpha (float): Transparency for faces (2-cells)
741739
**kwargs: Additional keyword arguments for plotting functions
742-
740+
743741
Returns:
744742
matplotlib.axes.Axes: The axes object with the plot.
745743
"""
@@ -991,20 +989,20 @@ def _build_incidence_csr(self) -> tuple:
991989

992990
cell_vertex_pointers = np.empty(n_cells + 1, dtype=np.int64)
993991
cell_euler_signs = np.empty(n_cells, dtype=np.int32)
994-
cell_vertex_indices_flat = []
992+
list_flat: List[int] = []
995993

996994
cell_vertex_pointers[0] = 0
997995
cell_index = 0
998996
for dim in dimensions:
999997
cells_in_dim = cells_by_dimension[dim]
1000998
euler_sign = 1 if (dim % 2 == 0) else -1
1001999
for cell_vertices in cells_in_dim:
1002-
cell_vertex_indices_flat.extend(cell_vertices)
1000+
list_flat.extend(cell_vertices)
10031001
cell_euler_signs[cell_index] = euler_sign
10041002
cell_index += 1
1005-
cell_vertex_pointers[cell_index] = len(cell_vertex_indices_flat)
1003+
cell_vertex_pointers[cell_index] = len(list_flat)
10061004

1007-
cell_vertex_indices_flat = np.asarray(cell_vertex_indices_flat, dtype=np.int32)
1005+
cell_vertex_indices_flat = np.asarray(list_flat, dtype=np.int32)
10081006
return (
10091007
cell_vertex_pointers,
10101008
cell_vertex_indices_flat,

src/ect/results.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from ect.directions import Sampling
44
from scipy.spatial.distance import cdist, pdist, squareform
5-
from typing import Union, List, Callable
5+
from typing import Union, List, Callable, cast
66

77

88
# ---------- CSR <-> Dense helpers (prefix-difference over thresholds) ----------
@@ -352,14 +352,15 @@ def dist(
352352
>>> # Batch distances with custom function
353353
>>> dists = ect1.dist([ect2, ect3, ect4], metric=my_distance)
354354
"""
355-
# normalize input to list
356355
single = isinstance(other, ECTResult)
357-
others = [other] if single else other
356+
others_list: List["ECTResult"] = cast(
357+
List["ECTResult"], [other] if single else other
358+
)
358359

359-
if not others:
360+
if not others_list:
360361
return np.array([])
361362

362-
for i, ect in enumerate(others):
363+
for i, ect in enumerate(others_list):
363364
if ect.shape != self.shape:
364365
raise ValueError(
365366
f"Shape mismatch at index {i}: {self.shape} vs {ect.shape}"
@@ -370,13 +371,15 @@ def dist(
370371
if single:
371372
b = np.asarray(other, dtype=np.float64)
372373
return float(np.sqrt(np.sum((a - b) ** 2)))
373-
b = np.stack([np.asarray(ect, dtype=np.float64) for ect in others], axis=0)
374+
b = np.stack(
375+
[np.asarray(ect, dtype=np.float64) for ect in others_list], axis=0
376+
)
374377
diff = b - a
375378
return np.sqrt(np.sum(diff * diff, axis=(1, 2)))
376379

377380
distances = cdist(
378381
self.ravel()[np.newaxis, :],
379-
np.vstack([ect.ravel() for ect in others]),
382+
np.vstack([ect.ravel() for ect in others_list]),
380383
metric=metric,
381384
**kwargs,
382385
)[0]
@@ -399,13 +402,25 @@ def dist_matrix(
399402
raise ValueError(f"Shape mismatch at index {i}: {shape0} vs {r.shape}")
400403

401404
if isinstance(metric, str) and metric.lower() in ("frobenius", "fro"):
402-
return np.vstack([results[i].dist(results, metric="frobenius") for i in range(len(results))])
405+
return np.vstack(
406+
[
407+
results[i].dist(results, metric="frobenius")
408+
for i in range(len(results))
409+
]
410+
)
403411

404412
if isinstance(metric, str):
405-
X = np.stack([np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0)
413+
X = np.stack(
414+
[np.asarray(r, dtype=np.float64).ravel() for r in results], axis=0
415+
)
406416
try:
407417
return squareform(pdist(X, metric=metric, **kwargs))
408418
except TypeError:
409419
return cdist(X, X, metric=metric, **kwargs)
410420

411-
return np.vstack([results[i].dist(results, metric=metric, **kwargs) for i in range(len(results))])
421+
return np.vstack(
422+
[
423+
results[i].dist(results, metric=metric, **kwargs)
424+
for i in range(len(results))
425+
]
426+
)

0 commit comments

Comments
 (0)