Skip to content

Commit cc29d76

Browse files
committed
Fix typing
1 parent 6db2e40 commit cc29d76

33 files changed

+243
-268
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454

5555
# Unit testing
5656
- name: Run unit tests & collect coverage
57-
run: pytest tests --cov=manify --cov-report=xml:coverage.xml
57+
run: BEARTYPE_ENABLE=true pytest tests --cov=manify --cov-report=xml:coverage.xml
5858

5959
# Check docstrings are in Google style
6060
- name: Check docstrings are in Google style

manify/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
"""Manify: A Python Library for Learning Non-Euclidean Representations."""
22

3-
from jaxtyping import install_import_hook
3+
if os.getenv("BEARTYPE_ENABLE", "false").lower() == "true":  # NOTE(review): this uses `os`, but the diff adds no `import os` to this module — confirm it exists in the full file
4+
from jaxtyping import install_import_hook
45

5-
install_import_hook("manify", "beartype.beartype")
6+
install_import_hook("manify", "beartype.beartype")
67

7-
from manify.curvature_estimation import (
8-
delta_hyperbolicity,
9-
greedy_signature_selection,
10-
sampled_delta_hyperbolicity,
11-
sectional_curvature,
12-
)
8+
from manify.curvature_estimation import greedy_signature_selection, sampled_delta_hyperbolicity, sectional_curvature
139
from manify.embedders import CoordinateLearning, ProductSpaceVAE, SiameseNetwork
1410
from manify.manifolds import Manifold, ProductManifold
1511
from manify.predictors import KappaGCN, ProductSpaceDT, ProductSpacePerceptron, ProductSpaceRF, ProductSpaceSVM

manify/clustering/fuzzy_kmeans.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020

2121
from __future__ import annotations
2222

23-
from typing import Literal, Optional, Union
24-
2523
import numpy as np
2624
import torch
25+
from beartype.typing import Literal
2726
from geoopt import ManifoldParameter
2827
from geoopt.optim import RiemannianAdam
2928
from jaxtyping import Float, Int
@@ -66,13 +65,13 @@ class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
6665
def __init__(
6766
self,
6867
n_clusters: int,
69-
manifold: Union[Manifold, ProductManifold],
68+
manifold: Manifold | ProductManifold,
7069
m: float = 2.0,
7170
lr: float = 0.1,
7271
max_iter: int = 100,
7372
tol: float = 1e-4,
7473
optimizer: Literal["adan", "adam"] = "adan",
75-
random_state: Optional[int] = None,
74+
random_state: int | None = None,
7675
verbose: bool = False,
7776
):
7877
self.n_clusters = n_clusters

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import Tuple
15-
1614
import torch
17-
from jaxtyping import Float
15+
from jaxtyping import Float, Int
1816

1917

2018
def sampled_delta_hyperbolicity(
2119
D: Float[torch.Tensor, "n_points n_points"], n_samples: int = 1000, reference_idx: int = 0, relative: bool = True
22-
) -> Tuple[Float[torch.Tensor, "n_samples"], Float[torch.Tensor, "n_samples 3"]]:
20+
) -> tuple[Float[torch.Tensor, "n_samples"], Int[torch.Tensor, "n_samples 3"]]:
2321
r"""Computes $\delta$-hyperbolicity by sampling random point triplets.
2422
2523
For large metric spaces, this approximates $\delta$-hyperbolicity by randomly sampling triplets. For each triplet

manify/curvature_estimation/greedy_method.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from __future__ import annotations
88

9-
from typing import Any, Tuple
10-
119
import torch
1210
from jaxtyping import Float
1311

@@ -17,9 +15,9 @@
1715
def greedy_signature_selection(
1816
pm: ProductManifold,
1917
dists: Float[torch.Tensor, "n_points n_points"],
20-
candidate_components: Tuple[Tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
18+
candidate_components: tuple[tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
2119
max_components: int = 3,
22-
) -> Any:
20+
) -> None:
2321
r"""Greedily estimates an optimal product manifold signature.
2422
2523
This implements the greedy signature selection algorithm that incrementally builds a product manifold

manify/embedders/_base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Any, Dict, List, Optional
76

87
import torch
8+
from beartype.typing import Any
99
from jaxtyping import Float
1010
from sklearn.base import BaseEstimator, TransformerMixin
1111

@@ -27,18 +27,18 @@ class BaseEmbedder(BaseEstimator, TransformerMixin, ABC):
2727
is_fitted_: Boolean flag indicating if the embedder has been fitted.
2828
"""
2929

30-
def __init__(self, pm: ProductManifold, random_state: Optional[int] = None, device: Optional[str] = None) -> None:
30+
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
3131
self.pm = pm
3232
self.random_state = random_state
3333
self.device = pm.device if device is None else device
34-
self.loss_history_: Dict[str, List[float]] = {}
34+
self.loss_history_: dict[str, list[float]] = {}
3535
self.is_fitted_: bool = False
3636

3737
@abstractmethod
3838
def fit(
3939
self,
40-
X: Optional[Float[torch.Tensor, "n_points n_features"]] = None,
41-
D: Optional[Float[torch.Tensor, "n_points n_points"]] = None,
40+
X: Float[torch.Tensor, "n_points n_features"] | None = None,
41+
D: Float[torch.Tensor, "n_points n_points"] | None = None,
4242
lr: float = 1e-2,
4343
burn_in_lr: float = 1e-3,
4444
curvature_lr: float = 0.0, # Off by default
@@ -67,7 +67,7 @@ def fit(
6767

6868
@abstractmethod
6969
def transform(
70-
self, X: Optional[Float[torch.Tensor, "n_points n_features"]]
70+
self, X: Float[torch.Tensor, "n_points n_features"] | None
7171
) -> Float[torch.Tensor, "n_points embedding_dim"]:
7272
"""Apply embedding to new data. Not defined for coordinate learning.
7373
@@ -81,8 +81,8 @@ def transform(
8181

8282
def fit_transform(
8383
self,
84-
X: Optional[Float[torch.Tensor, "n_points n_features"]] = None,
85-
D: Optional[Float[torch.Tensor, "n_points n_points"]] = None,
84+
X: Float[torch.Tensor, "n_points n_features"] | None = None,
85+
D: Float[torch.Tensor, "n_points n_points"] | None = None,
8686
**fit_kwargs: Any,
8787
) -> Float[torch.Tensor, "n_points embedding_dim"]:
8888
"""Fit the embedder and transform the data in one step.

manify/embedders/_losses.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77

88
from __future__ import annotations
99

10-
from typing import List
11-
1210
import networkx as nx
1311
import torch
1412
from jaxtyping import Float
1513

1614
from ..manifolds import ProductManifold
1715

16+
# TODO: Fix shape annotations for Float tensors with "..." placeholders
17+
1818

1919
def distortion_loss(
20-
D_est: Float[torch.Tensor, "n_points n_points"],
21-
D_true: Float[torch.Tensor, "n_points n_points"],
20+
D_est: Float[torch.Tensor, "..."],
21+
D_true: Float[torch.Tensor, "..."],
2222
pairwise: bool = False,
2323
) -> Float[torch.Tensor, ""]:
2424
r"""Computes the distortion loss between estimated and true squared distances.
@@ -59,8 +59,8 @@ def distortion_loss(
5959

6060

6161
def d_avg(
62-
D_est: Float[torch.Tensor, "n_points n_points"],
63-
D_true: Float[torch.Tensor, "n_points n_points"],
62+
D_est: Float[torch.Tensor, "..."],
63+
D_true: Float[torch.Tensor, "..."],
6464
pairwise: bool = False,
6565
) -> Float[torch.Tensor, ""]:
6666
r"""Computes the average relative distance error (D_avg).
@@ -102,7 +102,9 @@ def d_avg(
102102
return torch.mean(torch.abs(D_est - D_true) / D_true)
103103

104104

105-
def mean_average_precision(x_embed: Float[torch.Tensor, "n_points n_dim"], graph: nx.Graph) -> Float[torch.Tensor, ""]:
105+
def mean_average_precision(
106+
x_embed: Float[torch.Tensor, "n_points_dists n_dim"], graph: nx.Graph
107+
) -> Float[torch.Tensor, ""]:
106108
r"""Computes the mean average precision (mAP) for graph embedding evaluation.
107109
108110
This metric is used to evaluate how well an embedding preserves the neighborhood structure of a graph, as described
@@ -121,7 +123,9 @@ def mean_average_precision(x_embed: Float[torch.Tensor, "n_points n_dim"], graph
121123
raise NotImplementedError
122124

123125

124-
def dist_component_by_manifold(pm: ProductManifold, x_embed: Float[torch.Tensor, "n_points n_dim"]) -> List[float]:
126+
def dist_component_by_manifold(
127+
pm: ProductManifold, x_embed: Float[torch.Tensor, "n_points_dists n_dim"]
128+
) -> list[float]:
125129
r"""Computes the proportion of variance in pairwise distances explained by each manifold component.
126130
127131
The contribution is calculated as the ratio of the sum of squared distances in each component to the total squared

manify/embedders/coordinate_learning.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
"""Implementation for direct coordinate optimization in Riemannian manifolds.
22
33
This module provides functions for learning optimal embeddings in product manifolds by directly optimizing the
4-
coordinates using Riemannian optimization. This approach is particularly useful for embedding graphs using metric learning
5-
to maintain pairwise distances in the target space. The optimization is performed using Riemannian gradient descent
6-
with support for non-transductive training, in which gradients from the test set to the training set are masked out.
4+
coordinates using Riemannian optimization. This approach is particularly useful for embedding graphs using metric
5+
learning to maintain pairwise distances in the target space. The optimization is performed using Riemannian gradient
6+
descent with support for non-transductive training, in which gradients from the test set to the training set are masked
7+
out.
78
"""
89

910
from __future__ import annotations
1011

1112
import sys
1213
import warnings
13-
from typing import Any, Dict, List, Optional
1414

1515
import geoopt
1616
import numpy as np
1717
import torch
18+
from beartype.typing import Any
1819
from jaxtyping import Float, Int
1920

2021
from ..manifolds import ProductManifold
@@ -62,7 +63,7 @@ class CoordinateLearning(BaseEmbedder):
6263
device: Optional device for tensor computations.
6364
"""
6465

65-
def __init__(self, pm: ProductManifold, random_state: Optional[int] = None, device: Optional[str] = None) -> None:
66+
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
6667
super().__init__(pm=pm, random_state=random_state, device=device)
6768

6869
def fit( # type: ignore[override]
@@ -105,7 +106,8 @@ def fit( # type: ignore[override]
105106
raise ValueError("Distance matrix D is needed for coordinate learning")
106107
if X is not None:
107108
warnings.warn(
108-
"Input X has been given. This will be ignored during fitting. If you have provided a distance matrix, please run embedder.fit(None, D) instead."
109+
"Input X has been given. This will be ignored during fitting. If you have provided a distance matrix,"
110+
"please run embedder.fit(None, D) instead."
109111
)
110112

111113
# Set random seed if provided
@@ -115,7 +117,7 @@ def fit( # type: ignore[override]
115117
# Move everything to the device; initialize random embeddings
116118
n = D.shape[0]
117119
covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
118-
means = torch.stack([self.pm.mu0] * n).to(self.device)
120+
means = torch.vstack([self.pm.mu0] * n).to(self.device)
119121
X_embed, _ = self.pm.sample(z_mean=means, sigma_factorized=covs)
120122
D = D.to(self.device)
121123

@@ -134,7 +136,7 @@ def fit( # type: ignore[override]
134136
my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)
135137

136138
# Outer training loop - mostly setting optimizer learning rates up here
137-
losses: Dict[str, List[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}
139+
losses: dict[str, list[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}
138140

139141
# Actual training loop
140142
for i in range(burn_in_iterations + training_iterations):

manify/embedders/siamese.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from __future__ import annotations
1212

1313
import sys
14-
from typing import Dict, List, Optional, Tuple
1514

1615
import numpy as np
1716
import torch
@@ -59,10 +58,10 @@ def __init__(
5958
self,
6059
pm: ProductManifold,
6160
encoder: torch.nn.Module,
62-
decoder: Optional[torch.nn.Module] = None,
61+
decoder: torch.nn.Module | None = None,
6362
reconstruction_loss: str = "mse",
6463
beta: float = 1.0,
65-
random_state: Optional[int] = None,
64+
random_state: int | None = None,
6665
device: str = "cpu",
6766
):
6867
# Init both base classes

manify/embedders/vae.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from __future__ import annotations
1212

1313
import sys
14-
from typing import Dict, List, Optional, Tuple
1514

1615
import numpy as np
1716
import torch
@@ -74,7 +73,7 @@ def __init__(
7473
pm: ProductManifold,
7574
encoder: torch.nn.Module,
7675
decoder: torch.nn.Module,
77-
random_state: Optional[int] = None,
76+
random_state: int | None = None,
7877
device: str = "cpu",
7978
beta: float = 1.0,
8079
reconstruction_loss: torch.nn.modules.loss._Loss = torch.nn.MSELoss(reduction="none"),
@@ -102,7 +101,7 @@ def __init__(
102101

103102
def encode(
104103
self, x: Float[torch.Tensor, "batch_size n_features"]
105-
) -> Tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
104+
) -> tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
106105
r"""Encodes input data to obtain latent means and log-variances in the manifold.
107106
108107
This method processes input data through the encoder network to obtain parameters of the approximate posterior
@@ -140,10 +139,10 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_ambient"]) -> Float[torch.
140139
"""
141140
return self.decoder(z)
142141

143-
def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[
142+
def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> tuple[
144143
Float[torch.Tensor, "batch_size n_features"],
145144
Float[torch.Tensor, "batch_size n_ambient"],
146-
List[Float[torch.Tensor, "n_latent n_latent"]],
145+
list[Float[torch.Tensor, "batch_size n_latent n_latent"]],
147146
]:
148147
r"""Performs the forward pass of the VAE in product manifold space.
149148
@@ -181,7 +180,7 @@ def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[
181180
def kl_divergence(
182181
self,
183182
z_mean: Float[torch.Tensor, "batch_size n_latent"],
184-
sigma_factorized: List[Float[torch.Tensor, "n_latent n_latent"]],
183+
sigma_factorized: list[Float[torch.Tensor, "batch_size manifold_dim manifold_dim"]],
185184
) -> Float[torch.Tensor, "batch_size"]:
186185
r"""Computes the KL divergence between posterior and prior distributions in the manifold.
187186
@@ -214,7 +213,7 @@ def kl_divergence(
214213

215214
def elbo(
216215
self, x: Float[torch.Tensor, "batch_size n_features"]
217-
) -> Tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]:
216+
) -> tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]:
218217
r"""Computes the Evidence Lower Bound (ELBO) for the VAE objective.
219218
220219
The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term

0 commit comments

Comments
 (0)