Skip to content

Commit 8f585d2

Browse files
committed
Many typing fixes
1 parent 5147a81 commit 8f585d2

File tree

11 files changed

+940
-124
lines changed

11 files changed

+940
-124
lines changed

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44

55
import torch
66
from jaxtyping import Float
7+
from typing import Tuple
78

89

9-
def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
10-
n = dismat.shape[0]
10+
def sampled_delta_hyperbolicity(
11+
D: Float[torch.Tensor, "n_points n_points"], n_samples: int = 1000, reference_idx: int = 0
12+
) -> Tuple[Float[torch.Tensor, "n_samples,"], Float[torch.Tensor, "n_samples 3"]]:
13+
n = D.shape[0]
1114
# Sample n_samples triplets of points randomly
1215
indices = torch.randint(0, n, (n_samples, 3))
1316

@@ -17,31 +20,31 @@ def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
1720
x, y, z = indices.T
1821
w = reference_idx # set reference point
1922

20-
xy_w = 0.5 * (dismat[w, x] + dismat[w, y] - dismat[x, y])
21-
xz_w = 0.5 * (dismat[w, x] + dismat[w, z] - dismat[x, z])
22-
yz_w = 0.5 * (dismat[w, y] + dismat[w, z] - dismat[y, z])
23+
xy_w = 0.5 * (D[w, x] + D[w, y] - D[x, y])
24+
xz_w = 0.5 * (D[w, x] + D[w, z] - D[x, z])
25+
yz_w = 0.5 * (D[w, y] + D[w, z] - D[y, z])
2326

2427
# delta(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w
2528
deltas = torch.minimum(xy_w, yz_w) - xz_w
26-
diam = torch.max(dismat)
29+
diam = torch.max(D)
2730
rel_deltas = 2 * deltas / diam
2831

2932
return rel_deltas, indices
3033

3134

3235
def iterative_delta_hyperbolicity(
33-
dismat: Float[torch.Tensor, "n_points n_points"],
36+
D: Float[torch.Tensor, "n_points n_points"], reference_idx: int = 0
3437
) -> Float[torch.Tensor, "n_points n_points n_points"]:
3538
"""delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w"""
36-
n = dismat.shape[0]
37-
w = 0
39+
n = D.shape[0]
40+
w = reference_idx
3841
gromov_products = torch.zeros((n, n))
3942
deltas = torch.zeros((n, n, n))
4043

4144
# Get Gromov Products
4245
for x in range(n):
4346
for y in range(n):
44-
gromov_products[x, y] = gromov_product(w, x, y, dismat)
47+
gromov_products[x, y] = gromov_product(w, x, y, D)
4548

4649
# Get Deltas
4750
for x in range(n):
@@ -52,55 +55,52 @@ def iterative_delta_hyperbolicity(
5255
yz_w = gromov_products[y, z]
5356
deltas[x, y, z] = torch.minimum(xy_w, yz_w) - xz_w
5457

55-
diam = torch.max(dismat)
58+
diam = torch.max(D)
5659
rel_deltas = 2 * deltas / diam
5760

5861
return rel_deltas, gromov_products
5962

6063

61-
def gromov_product(i: int, j: int, k: int, dismat: Float[torch.Tensor, "n_points n_points"]) -> float:
64+
def gromov_product(i: int, j: int, k: int, D: Float[torch.Tensor, "n_points n_points"]) -> float:
6265
"""(j,k)_i = 0.5 (d(i,j) + d(i,k) - d(j,k))"""
63-
d_ij = dismat[i, j]
64-
d_ik = dismat[i, k]
65-
d_jk = dismat[j, k]
66-
return 0.5 * (d_ij + d_ik - d_jk)
66+
return float(0.5 * (D[i, j] + D[i, k] - D[j, k]))
6767

6868

6969
def delta_hyperbolicity(
70-
dismat: Float[torch.Tensor, "n_points n_points"], relative=True, full=False
70+
D: Float[torch.Tensor, "n_points n_points"], reference_idx: int = 0, relative: bool = True, full: bool = False
7171
) -> Float[torch.Tensor, ""]:
7272
"""
7373
Compute the delta-hyperbolicity of a metric space.
7474
7575
Args:
76-
dismat: Distance matrix of the metric space.
76+
D: Distance matrix of the metric space.
7777
relative: Whether to return the relative delta-hyperbolicity.
7878
full: Whether to return the full delta tensor or just the maximum delta.
7979
8080
Returns:
8181
delta: Delta-hyperbolicity of the metric space.
8282
"""
8383

84-
n = dismat.shape[0]
85-
p = 0
84+
n = D.shape[0]
85+
w = reference_idx
8686

87-
row = dismat[p, :].unsqueeze(0) # (1,N)
88-
col = dismat[:, p].unsqueeze(1) # (N,1)
89-
XY_p = 0.5 * (row + col - dismat)
87+
row = D[w, :].unsqueeze(0) # (1,N)
88+
col = D[:, w].unsqueeze(1) # (N,1)
89+
XY_w = 0.5 * (row + col - D)
9090

91-
XY_p_xy = XY_p.unsqueeze(2).expand(-1, -1, n) # (n,n,n)
92-
XY_p_yz = XY_p.unsqueeze(0).expand(n, -1, -1) # (n,n,n)
93-
XY_p_xz = XY_p.unsqueeze(1).expand(-1, n, -1) # (n,n,n)
91+
XY_w_xy = XY_w.unsqueeze(2).expand(-1, -1, n) # (n,n,n)
92+
XY_w_yz = XY_w.unsqueeze(0).expand(n, -1, -1) # (n,n,n)
93+
XY_w_xz = XY_w.unsqueeze(1).expand(-1, n, -1) # (n,n,n)
9494

95-
out = torch.minimum(XY_p_xy, XY_p_yz)
95+
out = torch.minimum(XY_w_xy, XY_w_yz)
9696

9797
if not full:
98-
delta = (out - XY_p_xz).max().item()
98+
delta = (out - XY_w_xz).max().item()
9999
else:
100-
delta = out - XY_p_xz
100+
delta = out - XY_w_xz
101101

102102
if relative:
103-
diam = torch.max(dismat).item()
103+
diam = torch.max(D).item()
104104
delta = 2 * delta / diam
105105

106106
return delta

manify/curvature_estimation/greedy_method.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Tuple
5+
from typing import Tuple, Any
66

77
import torch
88

@@ -12,12 +12,8 @@
1212
def greedy_curvature_method(
1313
pm: ProductManifold,
1414
dists: torch.Tensor,
15-
candidate_components: Tuple[Tuple[float, int], ...] = (
16-
(-1.0, 2),
17-
(0.0, 2),
18-
(1.0, 2),
19-
),
15+
candidate_components: Tuple[Tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
2016
max_components: int = 3,
21-
):
17+
) -> Any:
2218
"""The greedy curvature estimation method from Tabaghi et al. at https://arxiv.org/pdf/2102.10204"""
2319
raise NotImplementedError

manify/manifolds.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,7 @@ def sample(
196196
sigma: Optional[Float[torch.Tensor, "n_points n_dim n_dim"]] = None,
197197
) -> Union[
198198
Float[torch.Tensor, "n_points n_ambient_dim"],
199-
Tuple[
200-
Float[torch.Tensor, "n_points n_ambient_dim"],
201-
Float[torch.Tensor, "n_points n_dim"],
202-
],
199+
Tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]],
203200
]:
204201
"""
205202
Sample from the variational distribution.
@@ -228,7 +225,7 @@ def sample(
228225
N = torch.distributions.MultivariateNormal(
229226
loc=torch.zeros((n, self.dim), device=self.device), covariance_matrix=sigma
230227
)
231-
v = N.sample() # type: ignore
228+
v = N.sample()
232229

233230
# Don't need to adjust normal vectors for the Scaled manifold class in geoopt - very cool!
234231

@@ -356,14 +353,14 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
356353

357354
if self.is_stereographic:
358355
print("Manifold is already in stereographic coordinates.")
359-
return self, *points # type: ignore
356+
return self, *points
360357

361358
# Convert manifold
362359
stereo_manifold = Manifold(self.curvature, self.dim, device=self.device, stereographic=True)
363360

364361
# Euclidean edge case
365362
if self.type == "E":
366-
return stereo_manifold, *points # type: ignore
363+
return stereo_manifold, *points
367364

368365
# Convert points
369366
num = [X[:, 1:] for X in points]
@@ -373,7 +370,7 @@ def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple
373370
stereo_points = [n / d for n, d in zip(num, denom)]
374371
assert all([stereo_manifold.manifold.check_point(X) for X in stereo_points])
375372

376-
return stereo_manifold, *stereo_points # type: ignore
373+
return stereo_manifold, *stereo_points
377374

378375
def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_stereo"]) -> Tuple["Manifold", ...]:
379376
"""
@@ -389,14 +386,14 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
389386
"""
390387
if not self.is_stereographic:
391388
print("Manifold is already in original coordinates.")
392-
return self, *points # type: ignore
389+
return self, *points
393390

394391
# Convert manifold
395392
orig_manifold = Manifold(self.curvature, self.dim, device=self.device, stereographic=False)
396393

397394
# Euclidean edge case
398395
if self.type == "E":
399-
return orig_manifold, *points # type: ignore
396+
return orig_manifold, *points
400397

401398
# Inverse projection for points
402399
out = []
@@ -427,7 +424,7 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
427424

428425
out.append(inv_points)
429426

430-
return orig_manifold, *out # type: ignore
427+
return orig_manifold, *out
431428

432429
def apply(self, f: Callable) -> Callable:
433430
"""
@@ -475,11 +472,11 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
475472
# Actually initialize the geoopt manifolds; other derived properties
476473
self.P = [Manifold(curvature, dim, device=device, stereographic=stereographic) for curvature, dim in signature]
477474
manifold_class = geoopt.StereographicProductManifold if stereographic else geoopt.ProductManifold
478-
self.manifold = manifold_class(*[(M.manifold, M.ambient_dim) for M in self.P]).to(device) # type: ignore
475+
self.manifold = manifold_class(*[(M.manifold, M.ambient_dim) for M in self.P]).to(device)
479476
self.name = " x ".join([M.name for M in self.P])
480477

481478
# Origin
482-
self.mu0 = torch.cat([M.mu0 for M in self.P], axis=1).to(self.device) # type: ignore
479+
self.mu0 = torch.cat([M.mu0 for M in self.P], axis=1).to(self.device)
483480

484481
# Manifold <-> Dimension mapping
485482
self.ambient_dim, self.n_manifolds, self.dim = 0, 0, 0
@@ -507,11 +504,11 @@ def __init__(self, signature: List[Tuple[float, int]], device: str = "cpu", ster
507504
for j, k in zip(intrinsic_dims, ambient_dims[-len(intrinsic_dims) :]):
508505
self.projection_matrix[j, k] = 1.0
509506

510-
def params(self):
507+
def params(self) -> List[float]:
511508
"""Returns scales for all component manifolds"""
512509
return [x.scale() for x in self.manifold.manifolds]
513510

514-
def to(self, device: str):
511+
def to(self, device: str) -> "ProductManifold":
515512
"""Move all components to a new device"""
516513
self.device = device
517514
self.P = [M.to(device) for M in self.P]
@@ -628,24 +625,23 @@ def log_likelihood(
628625
M.log_likelihood(z_M, mu_M, sigma_M).unsqueeze(dim=1)
629626
for M, z_M, mu_M, sigma_M in zip(self.P, z_factorized, mu_factorized, sigma_factorized)
630627
]
631-
return torch.cat(component_lls, axis=1).sum(axis=1) # type: ignore
628+
return torch.cat(component_lls, axis=1).sum(axis=1)
632629

633630
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> Tuple["ProductManifold", ...]:
634631
if self.is_stereographic:
635632
print("Manifold is already in stereographic coordinates.")
636-
return self, *points # type: ignore
633+
return self, *points
637634

638635
# Convert manifold
639636
stereo_manifold = ProductManifold(self.signature, device=self.device, stereographic=True)
640637

641638
# Convert points
642639
stereo_points = [
643-
torch.hstack([M.stereographic(x)[1] for x, M in zip(self.factorize(X), self.P)]) # type: ignore
644-
for X in points
640+
torch.hstack([M.stereographic(x)[1] for x, M in zip(self.factorize(X), self.P)]) for X in points
645641
]
646642
assert all([stereo_manifold.manifold.check_point(X) for X in stereo_points])
647643

648-
return stereo_manifold, *stereo_points # type: ignore
644+
return stereo_manifold, *stereo_points
649645

650646
def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_stereo"]) -> Tuple[Manifold]:
651647
if not self.is_stereographic:
@@ -660,9 +656,9 @@ def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_ste
660656
]
661657
assert all([orig_manifold.manifold.check_point(X) for X in orig_points])
662658

663-
return orig_manifold, *orig_points # type: ignore
659+
return orig_manifold, *orig_points
664660

665-
@torch.no_grad()
661+
@torch.no_grad() # type: ignore
666662
def gaussian_mixture(
667663
self,
668664
num_points: int = 1_000,
@@ -713,7 +709,7 @@ def gaussian_mixture(
713709
z_mean=torch.stack([self.mu0] * num_clusters),
714710
sigma_factorized=[torch.stack([torch.eye(M.dim)] * num_clusters) * cov_scale_means for M in self.P],
715711
)
716-
assert cluster_means.shape == (num_clusters, self.ambient_dim) # type: ignore
712+
assert cluster_means.shape == (num_clusters, self.ambient_dim)
717713

718714
# Generate class assignments
719715
cluster_probs = torch.rand(num_clusters)
@@ -726,10 +722,8 @@ def gaussian_mixture(
726722

727723
# Generate covariance matrices for each class - Wishart distribution
728724
cov_matrices = [
729-
torch.distributions.Wishart(
730-
df=M.dim + 1, covariance_matrix=torch.eye(M.dim) * cov_scale_points # type: ignore
731-
).sample(
732-
sample_shape=(num_clusters,) # type: ignore
725+
torch.distributions.Wishart(df=M.dim + 1, covariance_matrix=torch.eye(M.dim) * cov_scale_points).sample(
726+
sample_shape=(num_clusters,)
733727
)
734728
+ torch.eye(M.dim) * 1e-5 # jitter to avoid singularity
735729
for M in self.P
@@ -764,7 +758,7 @@ def gaussian_mixture(
764758

765759
# Noise component
766760
N = torch.distributions.Normal(0, regression_noise_std)
767-
v = N.sample((num_points,)).to(self.device) # type: ignore
761+
v = N.sample((num_points,)).to(self.device)
768762
labels += v
769763

770764
# Normalize regression labels to range [0, 1] so that RMSE can be more easily interpreted

0 commit comments

Comments
 (0)