Skip to content

Commit 773bb0e

Browse files
committed
future annotations + isort + black
1 parent f76bd7a commit 773bb0e

21 files changed

+337
-139
lines changed

manify/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import manify.curvature_estimation
22
import manify.embedders
3+
import manify.manifolds
34
import manify.predictors
45
import manify.utils
5-
6-
import manify.manifolds
Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Compute delta-hyperbolicity of a metric space."""
22

3+
from __future__ import annotations
34

4-
from jaxtyping import Float
55
import torch
6+
from jaxtyping import Float
7+
68

79
def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
810
n = dismat.shape[0]
@@ -12,55 +14,62 @@ def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
1214
# Get gromov products
1315
# (j,k)_i = .5 (d(i,j) + d(i,k) - d(j,k))
1416

15-
x,y,z = indices.T
16-
w = reference_idx # set reference point
17+
x, y, z = indices.T
18+
w = reference_idx # set reference point
1719

18-
xy_w = .5 * (dismat[w,x] + dismat[w,y] - dismat[x,y])
19-
xz_w = .5 * (dismat[w,x] + dismat[w,z] - dismat[x,z])
20-
yz_w = .5 * (dismat[w,y] + dismat[w,z] - dismat[y,z])
20+
xy_w = 0.5 * (dismat[w, x] + dismat[w, y] - dismat[x, y])
21+
xz_w = 0.5 * (dismat[w, x] + dismat[w, z] - dismat[x, z])
22+
yz_w = 0.5 * (dismat[w, y] + dismat[w, z] - dismat[y, z])
2123

2224
# delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w
23-
deltas = torch.minimum(xy_w,yz_w) - xz_w
25+
deltas = torch.minimum(xy_w, yz_w) - xz_w
2426
diam = torch.max(dismat)
2527
rel_deltas = 2 * deltas / diam
2628

2729
return rel_deltas, indices
2830

31+
2932
def iterative_delta_hyperbolicity(dismat):
3033
"""delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w"""
3134
n = dismat.shape[0]
3235
w = 0
33-
gromov_products = torch.zeros((n,n))
34-
deltas = torch.zeros((n,n,n))
36+
gromov_products = torch.zeros((n, n))
37+
deltas = torch.zeros((n, n, n))
3538

3639
# Get Gromov Products
3740
for x in range(n):
3841
for y in range(n):
39-
gromov_products[x,y] = gromov_product(w,x,y,dismat)
42+
gromov_products[x, y] = gromov_product(w, x, y, dismat)
4043

4144
# Get Deltas
4245
for x in range(n):
4346
for y in range(n):
4447
for z in range(n):
45-
xz_w = gromov_products[x,z]
46-
xy_w = gromov_products[x,y]
47-
yz_w = gromov_products[y,z]
48-
deltas[x,y,z] = torch.minimum(xy_w,yz_w) - xz_w
49-
48+
xz_w = gromov_products[x, z]
49+
xy_w = gromov_products[x, y]
50+
yz_w = gromov_products[y, z]
51+
deltas[x, y, z] = torch.minimum(xy_w, yz_w) - xz_w
52+
5053
diam = torch.max(dismat)
5154
rel_deltas = 2 * deltas / diam
5255

5356
return rel_deltas, gromov_products
5457

5558

56-
def gromov_product(i,j,k,dismat):
59+
def gromov_product(i, j, k, dismat):
5760
"""(j,k)_i = 0.5 (d(i,j) + d(i,k) - d(j,k))"""
58-
d_ij = dismat[i,j]
59-
d_ik = dismat[i,k]
60-
d_jk = dismat[j,k]
61+
d_ij = dismat[i, j]
62+
d_ik = dismat[i, k]
63+
d_jk = dismat[j, k]
6164
return 0.5 * (d_ij + d_ik - d_jk)
6265

63-
def delta_hyperbolicity(dismat: Float[torch.Tensor, "n_points n_points"], relative=True, device='cpu', full=False) -> Float[torch.Tensor, ""]:
66+
67+
def delta_hyperbolicity(
68+
dismat: Float[torch.Tensor, "n_points n_points"],
69+
relative=True,
70+
device="cpu",
71+
full=False,
72+
) -> Float[torch.Tensor, ""]:
6473
"""
6574
Compute the delta-hyperbolicity of a metric space.
6675
@@ -69,7 +78,7 @@ def delta_hyperbolicity(dismat: Float[torch.Tensor, "n_points n_points"], relati
6978
relative: Whether to return the relative delta-hyperbolicity.
7079
device: Device to run the computation on.
7180
full: Whether to return the full delta tensor or just the maximum delta.
72-
81+
7382
Returns:
7483
delta: Delta-hyperbolicity of the metric space.
7584
"""
@@ -95,7 +104,5 @@ def delta_hyperbolicity(dismat: Float[torch.Tensor, "n_points n_points"], relati
95104
if relative:
96105
diam = torch.max(dismat).item()
97106
delta = 2 * delta / diam
98-
99-
return delta
100107

101-
108+
return delta

manify/curvature_estimation/greedy_method.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Greedy selection of signatures, as described in Tabaghi et al. at https://arxiv.org/pdf/2102.10204"""
22

3+
from __future__ import annotations
4+
35
from typing import Tuple
46

57
import torch
@@ -10,7 +12,11 @@
1012
def greedy_curvature_method(
1113
pm: ProductManifold,
1214
dists: torch.Tensor,
13-
candidate_components: Tuple[Tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
15+
candidate_components: Tuple[Tuple[float, int], ...] = (
16+
(-1.0, 2),
17+
(0.0, 2),
18+
(1.0, 2),
19+
),
1420
max_components: int = 3,
1521
):
1622
"""The greedy curvature estimation method from Tabaghi et al. at https://arxiv.org/pdf/2102.10204"""

manify/curvature_estimation/sectional_curvature.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from __future__ import annotations
2+
13
import random
24

35
import numpy as np
46

5-
67
# The next couple functions are taken from this repo:
78
# https://github.com/HazyResearch/hyperbolics
89
# Paper: https://openreview.net/pdf?id=HJxeWnCcF7
@@ -60,7 +61,7 @@ def sample(D, size, n_samples=100):
6061
def estimate(D, size, n_samples):
6162
samples = sample(D, size, n_samples)
6263
m1 = np.mean(samples)
63-
m2 = np.mean(samples**2)
64+
m2 = np.mean(samples ** 2)
6465
return samples
6566

6667

manify/embedders/coordinate_learning.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
"""Implementation for coordinate training and optimization"""
22

3+
from __future__ import annotations
4+
35
import sys
4-
from typing import Tuple, List, Dict
5-
from jaxtyping import Float, Int
6+
from typing import Dict, List, Tuple
67

8+
import geoopt
79
import numpy as np
810
import torch
9-
import geoopt
11+
from jaxtyping import Float, Int
1012

11-
from .losses import distortion_loss, d_avg
1213
from ..manifolds import ProductManifold
14+
from .losses import d_avg, distortion_loss
1315

1416
# TQDM: notebook or regular
1517
if "ipykernel" in sys.modules:

manify/embedders/losses.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Implementation of different measurement metrics"""
22

3+
from __future__ import annotations
4+
35
from typing import List
4-
from jaxtyping import Float
56

6-
import torch
77
import networkx as nx
8+
import torch
9+
from jaxtyping import Float
810

911
from ..manifolds import ProductManifold
1012

@@ -24,7 +26,7 @@ def distortion_loss(
2426
Returns:
2527
float: A float indicating the distortion loss, calculated as the sum of the squared relative
2628
errors between the estimated and true squared distances.
27-
29+
2830
See also: square_loss in HazyResearch hyperbolics repo:
2931
https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
3032
"""

manify/embedders/siamese.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Siamese network embedder"""
22

3-
from typing import List, Optional, Literal
4-
from jaxtyping import Float
3+
from __future__ import annotations
4+
5+
from typing import Optional
6+
57
import torch
8+
from jaxtyping import Float
69

710
from ..manifolds import ProductManifold
811

manify/embedders/vae.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Product space variational autoencoder implementation"""
22

3+
from __future__ import annotations
4+
35
from typing import List, Tuple
4-
from jaxtyping import Float
56

67
import torch
8+
from jaxtyping import Float
79

810
from ..manifolds import ProductManifold
911

@@ -40,7 +42,7 @@ def __init__(
4042

4143
def encode(
4244
self, x: Float[torch.Tensor, "batch_size n_features"]
43-
) -> Tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
45+
) -> Tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"],]:
4446
"""Must return z_mean, z_logvar"""
4547
z_mean_tangent, z_logvar = self.encoder(x)
4648
z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix # Adds zeros in the right places
@@ -51,7 +53,9 @@ def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.T
5153
"""Decoding in product space VAE"""
5254
return self.decoder(z)
5355

54-
def forward(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Tuple[
56+
def forward(
57+
self, x: Float[torch.Tensor, "batch_size n_features"]
58+
) -> Tuple[
5559
Float[torch.Tensor, "batch_size n_features"],
5660
Float[torch.Tensor, "batch_size n_latent"],
5761
List[Float[torch.Tensor, "n_latent n_latent"]],

manify/manifolds.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
includes functions for different key geometric operations.
99
"""
1010

11+
from __future__ import annotations
12+
1113
import warnings
12-
from typing import Callable, List, Literal, Optional, Tuple, Union
14+
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
1315

1416
import geoopt
1517
import torch
@@ -31,7 +33,13 @@ class Manifold:
3133
stereographic: (bool) Whether to use stereographic coordinates for the manifold.
3234
"""
3335

34-
def __init__(self, curvature: float, dim: int, device: str = "cpu", stereographic: bool = False):
36+
def __init__(
37+
self,
38+
curvature: float,
39+
dim: int,
40+
device: str = "cpu",
41+
stereographic: bool = False,
42+
):
3543
# Device management
3644
self.device = device
3745

@@ -95,7 +103,9 @@ def to(self, device: str) -> "Manifold":
95103
return self
96104

97105
def inner(
98-
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
106+
self,
107+
X: Float[torch.Tensor, "n_points1 n_dim"],
108+
Y: Float[torch.Tensor, "n_points2 n_dim"],
99109
) -> Float[torch.Tensor, "n_points1 n_points2"]:
100110
"""
101111
Compute the inner product of manifolds.
@@ -116,7 +126,9 @@ def inner(
116126
return X_fixed @ Y.T * scaler
117127

118128
def dist(
119-
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
129+
self,
130+
X: Float[torch.Tensor, "n_points1 n_dim"],
131+
Y: Float[torch.Tensor, "n_points2 n_dim"],
120132
) -> Float[torch.Tensor, "n_points1 n_points2"]:
121133
"""
122134
Inherit distance function from the geoopt manifold.
@@ -131,7 +143,9 @@ def dist(
131143
return self.manifold.dist(X[:, None], Y[None, :])
132144

133145
def dist2(
134-
self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
146+
self,
147+
X: Float[torch.Tensor, "n_points1 n_dim"],
148+
Y: Float[torch.Tensor, "n_points2 n_dim"],
135149
) -> Float[torch.Tensor, "n_points1 n_points2"]:
136150
"""
137151
Inherit squared distance function from the geoopt manifold.
@@ -194,7 +208,10 @@ def sample(
194208
sigma: Optional[Float[torch.Tensor, "n_points n_dim n_dim"]] = None,
195209
) -> Union[
196210
Float[torch.Tensor, "n_points n_ambient_dim"],
197-
Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]],
211+
Tuple[
212+
Float[torch.Tensor, "n_points n_ambient_dim"],
213+
Float[torch.Tensor, "n_points n_dim"],
214+
],
198215
]:
199216
"""
200217
Sample from the variational distribution.
@@ -304,7 +321,9 @@ def log_likelihood(
304321
return ll - (n - 1) * torch.log(R * torch.abs(sin_M(u_norm / R) / u_norm) + 1e-8)
305322

306323
def logmap(
307-
self, x: Float[torch.Tensor, "n_points n_dim"], base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None
324+
self,
325+
x: Float[torch.Tensor, "n_points n_dim"],
326+
base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None,
308327
) -> Float[torch.Tensor, "n_points n_dim"]:
309328
"""
310329
Logarithmic map of point on manifold x at base point.
@@ -322,7 +341,9 @@ def logmap(
322341
return self.manifold.logmap(x=base, y=x)
323342

324343
def expmap(
325-
self, u: Float[torch.Tensor, "n_points n_dim"], base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None
344+
self,
345+
u: Float[torch.Tensor, "n_points n_dim"],
346+
base: Optional[Float[torch.Tensor, "n_points n_dim"]] = None,
326347
) -> Float[torch.Tensor, "n_points n_dim"]:
327348
"""
328349
Exponential map of tangent vector u at base point.
@@ -394,7 +415,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
394415
return orig_manifold, *points # type: ignore
395416

396417
# Inverse projection for points
397-
norm_squared = [(Y**2).sum(dim=1, keepdim=True) for Y in points]
418+
norm_squared = [(Y ** 2).sum(dim=1, keepdim=True) for Y in points]
398419
sign = torch.sign(self.curvature) # type: ignore
399420

400421
X0 = (1 + sign * norm_squared) / (1 - sign * norm_squared)
@@ -436,7 +457,12 @@ class ProductManifold(Manifold):
436457
stereographic: (bool) Whether to use stereographic coordinates for the manifold.
437458
"""
438459

439-
def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", stereographic: bool = False):
460+
def __init__(
461+
self,
462+
signature: List[Tuple[float, int]],
463+
device: str = "cpu",
464+
stereographic: bool = False,
465+
):
440466
# Device management
441467
self.device = device
442468

@@ -459,7 +485,12 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
459485

460486
# Manifold <-> Dimension mapping
461487
self.ambient_dim, self.n_manifolds, self.dim = 0, 0, 0
462-
self.dim2man, self.man2dim, self.man2intrinsic, self.intrinsic2man = {}, {}, {}, {}
488+
self.dim2man, self.man2dim, self.man2intrinsic, self.intrinsic2man = (
489+
{},
490+
{},
491+
{},
492+
{},
493+
)
463494

464495
for M in self.P:
465496
for d in range(self.ambient_dim, self.ambient_dim + M.ambient_dim):
@@ -518,7 +549,10 @@ def sample(
518549
sigma_factorized: Optional[List[Float[torch.Tensor, "n_points n_dim_manifold n_dim_manifold"]]] = None,
519550
) -> Union[
520551
Float[torch.Tensor, "n_points n_ambient_dim"],
521-
Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]],
552+
Tuple[
553+
Float[torch.Tensor, "n_points n_ambient_dim"],
554+
Float[torch.Tensor, "n_points n_dim"],
555+
],
522556
]:
523557
"""
524558
Sample from the variational distribution.

0 commit comments

Comments
 (0)