Skip to content

Commit 3591857

Browse files
committed
Fix types everywhere
1 parent 2b3657d commit 3591857

File tree

17 files changed

+289
-397
lines changed

17 files changed

+289
-397
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
fail-fast: false # don’t stop the matrix if one Python version fails
1414
matrix:
15-
python-version: ["3.9", "3.10", "3.11"]
15+
python-version: ["3.10", "3.11"] # jaxtyping requires >= 3.10; scipy requires < 3.12
1616

1717
steps:
1818
# Setup and installation

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
2929
return rel_deltas, indices
3030

3131

32-
def iterative_delta_hyperbolicity(dismat):
32+
def iterative_delta_hyperbolicity(
33+
dismat: Float[torch.Tensor, "n_points n_points"],
34+
) -> Float[torch.Tensor, "n_points n_points n_points"]:
3335
"""delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w"""
3436
n = dismat.shape[0]
3537
w = 0
@@ -56,7 +58,7 @@ def iterative_delta_hyperbolicity(dismat):
5658
return rel_deltas, gromov_products
5759

5860

59-
def gromov_product(i, j, k, dismat):
61+
def gromov_product(i: int, j: int, k: int, dismat: Float[torch.Tensor, "n_points n_points"]) -> float:
6062
"""(j,k)_i = 0.5 (d(i,j) + d(i,k) - d(j,k))"""
6163
d_ij = dismat[i, j]
6264
d_ik = dismat[i, k]
@@ -65,18 +67,14 @@ def gromov_product(i, j, k, dismat):
6567

6668

6769
def delta_hyperbolicity(
68-
dismat: Float[torch.Tensor, "n_points n_points"],
69-
relative=True,
70-
device="cpu",
71-
full=False,
70+
dismat: Float[torch.Tensor, "n_points n_points"], relative=True, full=False
7271
) -> Float[torch.Tensor, ""]:
7372
"""
7473
Compute the delta-hyperbolicity of a metric space.
7574
7675
Args:
7776
dismat: Distance matrix of the metric space.
7877
relative: Whether to return the relative delta-hyperbolicity.
79-
device: Device to run the computation on.
8078
full: Whether to return the full delta tensor or just the maximum delta.
8179
8280
Returns:

manify/curvature_estimation/sectional_curvature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def sample(D, size, n_samples=100):
6161
def estimate(D, size, n_samples):
6262
samples = sample(D, size, n_samples)
6363
m1 = np.mean(samples)
64-
m2 = np.mean(samples ** 2)
64+
m2 = np.mean(samples**2)
6565
return samples
6666

6767

manify/embedders/coordinate_learning.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
def train_coords(
2424
pm: ProductManifold,
2525
dists: Float[torch.Tensor, "n_points n_points"],
26-
test_indices: Int[torch.Tensor, "n_test"] = torch.tensor([]),
26+
test_indices: Int[torch.Tensor, "n_test,"] = torch.tensor([]),
2727
device: str = "cpu",
2828
burn_in_learning_rate: float = 1e-3,
2929
burn_in_iterations: int = 2_000,
@@ -32,7 +32,6 @@ def train_coords(
3232
training_iterations: int = 18_000,
3333
loss_window_size: int = 100,
3434
logging_interval: int = 10,
35-
scale=1.0,
3635
) -> Tuple[Float[torch.Tensor, "n_points n_dim"], Dict[str, List[float]]]:
3736
"""
3837
Coordinate training and optimization
@@ -79,7 +78,7 @@ def train_coords(
7978
my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)
8079

8180
# Outer training loop - mostly setting optimizer learning rates up here
82-
losses = {"train_train": [], "test_test": [], "train_test": [], "total": []}
81+
losses: Dict[str, List[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}
8382

8483
# Actual training loop
8584
for i in range(burn_in_iterations + training_iterations):

manify/embedders/siamese.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
if decoder is not None:
2626
self.decoder = decoder
2727
else:
28-
self.decoder = lambda x: x
28+
self.decoder = torch.nn.Identity()
2929
self.decoder.requires_grad_(False)
3030
self.decoder.to(pm.device)
3131

manify/embedders/vae.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ def __init__(
4040
else:
4141
raise ValueError(f"Unknown reconstruction loss: {reconstruction_loss}")
4242

43-
def encode(
44-
self, x: Float[torch.Tensor, "batch_size n_features"]
45-
) -> Tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"],]:
43+
def encode(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[
44+
Float[torch.Tensor, "batch_size n_latent"],
45+
Float[torch.Tensor, "batch_size n_latent"],
46+
]:
4647
"""Must return z_mean, z_logvar"""
4748
z_mean_tangent, z_logvar = self.encoder(x)
4849
z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix # Adds zeros in the right places
@@ -53,9 +54,7 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.T
5354
"""Decoding in product space VAE"""
5455
return self.decoder(z)
5556

56-
def forward(
57-
self, x: Float[torch.Tensor, "batch_size n_features"]
58-
) -> Tuple[
57+
def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[
5958
Float[torch.Tensor, "batch_size n_features"],
6059
Float[torch.Tensor, "batch_size n_latent"],
6160
List[Float[torch.Tensor, "n_latent n_latent"]],
@@ -82,7 +81,7 @@ def kl_divergence(
8281
self,
8382
z_mean: Float[torch.Tensor, "batch_size n_latent"],
8483
sigma_factorized: List[Float[torch.Tensor, "n_latent n_latent"]],
85-
) -> Float[torch.Tensor, "batch_size"]:
84+
) -> Float[torch.Tensor, "batch_size,"]:
8685
"""
8786
Computes the KL divergence between posterior and prior distributions.
8887

manify/manifolds.py

Lines changed: 16 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import annotations
1212

1313
import warnings
14-
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
14+
from typing import Callable, List, Literal, Optional, Tuple, Union
1515

1616
import geoopt
1717
import torch
@@ -33,13 +33,7 @@ class Manifold:
3333
stereographic: (bool) Whether to use stereographic coordinates for the manifold.
3434
"""
3535

36-
def __init__(
37-
self,
38-
curvature: float,
39-
dim: int,
40-
device: str = "cpu",
41-
stereographic: bool = False,
42-
):
36+
def __init__(self, curvature: float, dim: int, device: str = "cpu", stereographic: bool = False):
4337
# Device management
4438
self.device = device
4539

@@ -103,9 +97,7 @@ def to(self, device: str) -> "Manifold":
10397
return self
10498

10599
def inner(
106-
self,
107-
X: Float[torch.Tensor, "n_points1 n_dim"],
108-
Y: Float[torch.Tensor, "n_points2 n_dim"],
100+
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
109101
) -> Float[torch.Tensor, "n_points1 n_points2"]:
110102
"""
111103
Compute the inner product of manifolds.
@@ -126,9 +118,7 @@ def inner(
126118
return X_fixed @ Y.T * scaler
127119

128120
def dist(
129-
self,
130-
X: Float[torch.Tensor, "n_points1 n_dim"],
131-
Y: Float[torch.Tensor, "n_points2 n_dim"],
121+
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
132122
) -> Float[torch.Tensor, "n_points1 n_points2"]:
133123
"""
134124
Inherit distance function from the geoopt manifold.
@@ -143,9 +133,7 @@ def dist(
143133
return self.manifold.dist(X[:, None], Y[None, :])
144134

145135
def dist2(
146-
self,
147-
X: Float[torch.Tensor, "n_points1 n_dim"],
148-
Y: Float[torch.Tensor, "n_points2 n_dim"],
136+
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
149137
) -> Float[torch.Tensor, "n_points1 n_points2"]:
150138
"""
151139
Inherit squared distance function from the geoopt manifold.
@@ -265,7 +253,7 @@ def log_likelihood(
265253
z: Float[torch.Tensor, "n_points n_ambient_dim"],
266254
mu: Optional[Float[torch.Tensor, "n_points n_ambient_dim"]] = None,
267255
sigma: Optional[Float[torch.Tensor, "n_points n_dim n_dim"]] = None,
268-
) -> Float[torch.Tensor, "n_points"]:
256+
) -> Float[torch.Tensor, "n_points,"]:
269257
"""
270258
Probability density function for WN(z ; mu, Sigma) in manifold
271259
@@ -321,9 +309,7 @@ def log_likelihood(
321309
return ll - (n - 1) * torch.log(R * torch.abs(sin_M(u_norm / R) / u_norm) + 1e-8)
322310

323311
def logmap(
324-
self,
325-
x: Float[torch.Tensor, "n_points n_dim"],
326-
base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None,
312+
self, x: Float[torch.Tensor, "n_points n_dim"], base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None
327313
) -> Float[torch.Tensor, "n_points n_dim"]:
328314
"""
329315
Logarithmic map of point on manifold x at base point.
@@ -341,9 +327,7 @@ def logmap(
341327
return self.manifold.logmap(x=base, y=x)
342328

343329
def expmap(
344-
self,
345-
u: Float[torch.Tensor, "n_points n_dim"],
346-
base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None,
330+
self, u: Float[torch.Tensor, "n_points n_dim"], base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None
347331
) -> Float[torch.Tensor, "n_points n_dim"]:
348332
"""
349333
Exponential map of tangent vector u at base point.
@@ -415,7 +399,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
415399
return orig_manifold, *points # type: ignore
416400

417401
# Inverse projection for points
418-
norm_squared = [(Y ** 2).sum(dim=1, keepdim=True) for Y in points]
402+
norm_squared = [(Y**2).sum(dim=1, keepdim=True) for Y in points]
419403
sign = torch.sign(self.curvature) # type: ignore
420404

421405
X0 = (1 + sign * norm_squared) / (1 - sign * norm_squared)
@@ -457,12 +441,7 @@ class ProductManifold(Manifold):
457441
stereographic: (bool) Whether to use stereographic coordinates for the manifold.
458442
"""
459443

460-
def __init__(
461-
self,
462-
signature: List[Tuple[float, int]],
463-
device: str = "cpu",
464-
stereographic: bool = False,
465-
):
444+
def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", stereographic: bool = False):
466445
# Device management
467446
self.device = device
468447

@@ -485,12 +464,7 @@ def __init__(
485464

486465
# Manifold <-> Dimension mapping
487466
self.ambient_dim, self.n_manifolds, self.dim = 0, 0, 0
488-
self.dim2man, self.man2dim, self.man2intrinsic, self.intrinsic2man = (
489-
{},
490-
{},
491-
{},
492-
{},
493-
)
467+
self.dim2man, self.man2dim, self.man2intrinsic, self.intrinsic2man = {}, {}, {}, {}
494468

495469
for M in self.P:
496470
for d in range(self.ambient_dim, self.ambient_dim + M.ambient_dim):
@@ -549,10 +523,7 @@ def sample(
549523
sigma_factorized: Optional[List[Float[torch.Tensor, "n_points n_dim_manifold n_dim_manifold"]]] = None,
550524
) -> Union[
551525
Float[torch.Tensor, "n_points n_ambient_dim"],
552-
Tuple[
553-
Float[torch.Tensor, "n_points n_ambient_dim"],
554-
Float[torch.Tensor, "n_points n_dim"],
555-
],
526+
Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]],
556527
]:
557528
"""
558529
Sample from the variational distribution.
@@ -593,9 +564,9 @@ def sample(
593564
def log_likelihood(
594565
self,
595566
z: Float[torch.Tensor, "batch_size n_dim"],
596-
mu: Optional[Float[torch.Tensor, "n_dim"]] = None,
567+
mu: Optional[Float[torch.Tensor, "n_dim,"]] = None,
597568
sigma_factorized: Optional[List[Float[torch.Tensor, "n_points n_dim_manifold n_dim_manifold"]]] = None,
598-
) -> Float[torch.Tensor, "batch_size"]:
569+
) -> Float[torch.Tensor, "batch_size,"]:
599570
"""
600571
Probability density function for WN(z ; mu, Sigma) in manifold
601572
@@ -624,7 +595,7 @@ def log_likelihood(
624595
]
625596
return torch.cat(component_lls, axis=1).sum(axis=1) # type: ignore
626597

627-
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple[Manifold, ...]:
598+
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple["ProductManifold", ...]:
628599
if self.is_stereographic:
629600
print("Manifold is already in stereographic coordinates.")
630601
return self, *points # type: ignore
@@ -653,7 +624,7 @@ def gaussian_mixture(
653624
regression_noise_std: float = 0.1,
654625
task: Literal["classification", "regression"] = "classification",
655626
adjust_for_dims: bool = False,
656-
) -> Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points"]]:
627+
) -> Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points,"]]:
657628
"""
658629
Generate a set of labeled samples from a Gaussian mixture model.
659630

0 commit comments

Comments (0)