Skip to content

Commit a599aa7

Browse files
committed
Fix jaxtyping annotations
1 parent 9870369 commit a599aa7

File tree

19 files changed

+495
-196
lines changed

19 files changed

+495
-196
lines changed

manify/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Manify: A Python Library for Learning Non-Euclidean Representations."""
22

3+
from jaxtyping import install_import_hook
4+
5+
install_import_hook("manify", "beartype.beartype")
6+
37
from manify.curvature_estimation import (
48
delta_hyperbolicity,
59
greedy_signature_selection,

manify/curvature_estimation/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
* `sectional_curvature`: Estimates the sectional curvature of a graph from its distance matrix.
88
"""
99

10-
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, sampled_delta_hyperbolicity
10+
from manify.curvature_estimation.delta_hyperbolicity import vectorized_delta_hyperbolicity, sampled_delta_hyperbolicity
1111
from manify.curvature_estimation.greedy_method import greedy_signature_selection
1212
from manify.curvature_estimation.sectional_curvature import sectional_curvature
1313

14-
__all__ = ["greedy_signature_selection", "sectional_curvature", "sampled_delta_hyperbolicity", "delta_hyperbolicity"]
14+
__all__ = [
15+
"greedy_signature_selection",
16+
"sectional_curvature",
17+
"sampled_delta_hyperbolicity",
18+
"vectorized_delta_hyperbolicity",
19+
]

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
def sampled_delta_hyperbolicity(
2121
D: Float[torch.Tensor, "n_points n_points"], n_samples: int = 1000, reference_idx: int = 0, relative: bool = True
22-
) -> Tuple[Float[torch.Tensor, "n_samples,"], Float[torch.Tensor, "n_samples 3"]]:
22+
) -> Tuple[Float[torch.Tensor, "n_samples"], Float[torch.Tensor, "n_samples 3"]]:
2323
r"""Computes $\delta$-hyperbolicity by sampling random point triplets.
2424
2525
For large metric spaces, this approximates $\delta$-hyperbolicity by randomly sampling triplets. For each triplet
@@ -61,7 +61,7 @@ def sampled_delta_hyperbolicity(
6161
return deltas, indices
6262

6363

64-
def delta_hyperbolicity(
64+
def vectorized_delta_hyperbolicity(
6565
D: Float[torch.Tensor, "n_points n_points"], reference_idx: int = 0, relative: bool = True, full: bool = False
6666
) -> Float[torch.Tensor, "n_points n_points n_points"]:
6767
r"""Computes the exact delta-hyperbolicity of a metric space over all point triplets.

manify/curvature_estimation/greedy_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
from typing import Any, Tuple
10+
from jaxtyping import Float
1011

1112
import torch
1213

@@ -15,7 +16,7 @@
1516

1617
def greedy_signature_selection(
1718
pm: ProductManifold,
18-
dists: torch.Tensor,
19+
dists: Float[torch.Tensor, "n_points n_points"],
1920
candidate_components: Tuple[Tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
2021
max_components: int = 3,
2122
) -> Any:

manify/embedders/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
class BaseEmbedder(BaseEstimator, TransformerMixin, ABC):
16-
"""Base class for everything in manify.embedders.
16+
"""Base class for everything in `manify.embedders`.
1717
1818
This is an abstract class that that defines a common interface for all embedding methods. We assume only that a
1919
ProductManifold object is given. We try to follow the scikit-learn API's fit/transform paradigm as closely as

manify/embedders/coordinate_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def fit( # type: ignore[override]
7070
self,
7171
X: None,
7272
D: Float[torch.Tensor, "n_points n_points"],
73-
test_indices: Int[torch.Tensor, "n_test,"] = torch.tensor([]),
73+
test_indices: Int[torch.Tensor, "n_test"] = torch.tensor([]),
7474
lr: float = 1e-2,
7575
burn_in_lr: float = 1e-3,
7676
curvature_lr: float = 0.0, # Off by default

manify/embedders/siamese.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def forward(
118118
) -> Tuple[
119119
Float[torch.Tensor, "batch_size n_latent"],
120120
Float[torch.Tensor, "batch_size n_latent"],
121-
Float[torch.Tensor, "batch_size,"],
121+
Float[torch.Tensor, "batch_size"],
122122
Float[torch.Tensor, "batch_size n_features"],
123123
Float[torch.Tensor, "batch_size n_features"],
124124
]:

manify/embedders/vae.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def kl_divergence(
182182
self,
183183
z_mean: Float[torch.Tensor, "batch_size n_latent"],
184184
sigma_factorized: List[Float[torch.Tensor, "n_latent n_latent"]],
185-
) -> Float[torch.Tensor, "batch_size,"]:
185+
) -> Float[torch.Tensor, "batch_size"]:
186186
r"""Computes the KL divergence between posterior and prior distributions in the manifold.
187187
188188
For distributions in Riemannian manifolds, computing the KL divergence analytically
@@ -358,8 +358,8 @@ def fit( # type: ignore[override]
358358
return self
359359

360360
def transform(
361-
self, X: Float["n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
362-
) -> Float["n_points embedding_dim"]:
361+
self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
362+
) -> Float[torch.Tensor, "n_points embedding_dim"]:
363363
"""Transform data using the trained VAE. Outputs means of the variational distribution.
364364
365365
Args:

manify/manifolds.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def log_likelihood(
262262
z: Float[torch.Tensor, "n_points n_ambient_dim"],
263263
mu: Optional[Float[torch.Tensor, "n_points n_ambient_dim"]] = None,
264264
sigma: Optional[Float[torch.Tensor, "n_points n_dim n_dim"]] = None,
265-
) -> Float[torch.Tensor, "n_points,"]:
265+
) -> Float[torch.Tensor, "n_points"]:
266266
r"""Compute probability density function for $\mathcal{WN}(\mathbf{z}; \mu, \Sigma)$ on the manifold.
267267
268268
Args:
@@ -646,9 +646,9 @@ def sample(
646646
def log_likelihood(
647647
self,
648648
z: Float[torch.Tensor, "batch_size n_dim"],
649-
mu: Optional[Float[torch.Tensor, "n_dim,"]] = None,
649+
mu: Optional[Float[torch.Tensor, "n_dim"]] = None,
650650
sigma_factorized: Optional[List[Float[torch.Tensor, "n_points n_dim_manifold n_dim_manifold"]]] = None,
651-
) -> Float[torch.Tensor, "batch_size,"]:
651+
) -> Float[torch.Tensor, "batch_size"]:
652652
r"""Compute probability density function for $\mathcal{WN}(\mathbf{z} ; \mu, \Sigma)$ on the product manifold.
653653
654654
Args:
@@ -756,7 +756,7 @@ def gaussian_mixture(
756756
regression_noise_std: float = 0.1,
757757
task: Literal["classification", "regression"] = "classification",
758758
adjust_for_dims: bool = False,
759-
) -> Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points,"]]:
759+
) -> Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points"]]:
760760
"""Generate a set of labeled samples from a Gaussian mixture model.
761761
762762
Args:

manify/predictors/_base.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Base predictor class."""
2+
3+
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional

import torch
from jaxtyping import Float, Int
from sklearn.base import BaseEstimator

from ..manifolds import ProductManifold
13+
14+
15+
class BasePredictor(BaseEstimator, ABC):
    """Base class for everything in `manify.predictors`.

    This is an abstract class that defines a common interface for all mixed-curvature predictors. We assume only that a
    ProductManifold object is given. We try to follow the scikit-learn API's fit/predict_proba/predict paradigm as
    closely as possible, while accommodating the nuances of product manifold geometry and Pytorch/Geoopt.

    Attributes:
        pm: ProductManifold object associated with the predictor.
        task: Task type, either "classification" or "regression".
        random_state: Random state for reproducibility.
        device: Device for tensor computations. If not provided, defaults to pm.device.
        loss_history_: History of loss values during training.
        is_fitted_: Boolean flag indicating if the predictor has been fitted.
    """

    def __init__(
        self,
        pm: ProductManifold,
        task: Literal["classification", "regression"],
        random_state: Optional[int] = None,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the predictor with a manifold and task type.

        Args:
            pm: ProductManifold object associated with the predictor.
            task: Task type, either "classification" or "regression".
            random_state: Random state for reproducibility.
            device: Device for tensor computations. Defaults to `pm.device` when not given.
        """
        self.pm = pm
        self.task = task
        self.random_state = random_state
        # Fall back to the manifold's device so tensors created by subclasses land in the right place.
        self.device = pm.device if device is None else device
        self.loss_history_: Dict[str, List[float]] = {}
        self.is_fitted_: bool = False

    @abstractmethod
    def fit(
        self, X: Float[torch.Tensor, "n_points n_features"], y: Float[torch.Tensor, "n_points n_classes"]
    ) -> "BasePredictor":
        """Abstract method to fit a predictor. Requires features and labels.

        Args:
            X: Features to fit.
            y: Labels for the features.

        Returns:
            self: Fitted predictor instance.
        """
        pass

    @abstractmethod
    def predict_proba(
        self, X: Optional[Float[torch.Tensor, "n_points n_features"]]
    ) -> Float[torch.Tensor, "n_points n_classes"]:
        """Compute the predicted probabilities for the given features.

        Args:
            X: New inputs for which to make predictions.

        Returns:
            X_proba: Predicted probabilities for the input features.
        """
        pass

    def predict(
        self, X: Optional[Float[torch.Tensor, "n_points n_features"]]
    ) -> Float[torch.Tensor, "n_points n_classes"] | Int[torch.Tensor, "n_points"]:
        """Compute the predicted outputs for the given features.

        Args:
            X: New inputs for which to make predictions.

        Returns:
            y_pred: For regression, the raw outputs of `predict_proba` (shape `(n_points, n_classes)`);
                for classification, the argmax class index per point (integer tensor of shape `(n_points,)`).
        """
        # Regression models expose their outputs directly through predict_proba.
        if self.task == "regression":
            return self.predict_proba(X=X)
        # Classification: pick the highest-probability class per point.
        return self.predict_proba(X=X).argmax(dim=-1)

0 commit comments

Comments
 (0)