Skip to content

Commit 1518a8d

Browse files
committed
Added Delta Hyperbolicity calculation
1 parent a488510 commit 1518a8d

File tree

1 file changed

+98
-3
lines changed

1 file changed

+98
-3
lines changed
Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,101 @@
11
"""Compute delta-hyperbolicity of a metric space."""
22

33

4-
def delta_hyperbolicity(dists):
5-
"""Estimate the delta hyperbolicity of a metric space using the method from https://arxiv.org/pdf/2204.08621"""
6-
raise NotImplementedError
4+
from jaxtyping import Float
5+
import torch
6+
7+
def sampled_delta_hyperbolicity(dismat, n_samples=1000, reference_idx=0):
8+
n = dismat.shape[0]
9+
# Sample n_samples triplets of points randomly
10+
indices = torch.randint(0, n, (n_samples, 3))
11+
12+
# Get gromov products
13+
# (j,k)_i = .5 (d(i,j) + d(i,k) - d(j,k))
14+
15+
x,y,z = indices.T
16+
w = reference_idx # set reference point
17+
18+
xy_w = .5 * (dismat[w,x] + dismat[w,y] - dismat[x,y])
19+
xz_w = .5 * (dismat[w,x] + dismat[w,z] - dismat[x,z])
20+
yz_w = .5 * (dismat[w,y] + dismat[w,z] - dismat[y,z])
21+
22+
# delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w
23+
deltas = torch.minimum(xy_w,yz_w) - xz_w
24+
diam = torch.max(dismat)
25+
rel_deltas = 2 * deltas / diam
26+
27+
return rel_deltas, indices
28+
29+
def iterative_delta_hyperbolicity(dismat):
30+
"""delta(x,y,z) = min((x,y)_w,(y-z)_w) - (x,z)_w"""
31+
n = dismat.shape[0]
32+
w = 0
33+
gromov_products = torch.zeros((n,n))
34+
deltas = torch.zeros((n,n,n))
35+
36+
# Get Gromov Products
37+
for x in range(n):
38+
for y in range(n):
39+
gromov_products[x,y] = gromov_product(w,x,y,dismat)
40+
41+
# Get Deltas
42+
for x in range(n):
43+
for y in range(n):
44+
for z in range(n):
45+
xz_w = gromov_products[x,z]
46+
xy_w = gromov_products[x,y]
47+
yz_w = gromov_products[y,z]
48+
deltas[x,y,z] = torch.minimum(xy_w,yz_w) - xz_w
49+
50+
diam = torch.max(dismat)
51+
rel_deltas = 2 * deltas / diam
52+
53+
return rel_deltas, gromov_products
54+
55+
56+
def gromov_product(i,j,k,dismat):
57+
"""(j,k)_i = 0.5 (d(i,j) + d(i,k) - d(j,k))"""
58+
d_ij = dismat[i,j]
59+
d_ik = dismat[i,k]
60+
d_jk = dismat[j,k]
61+
return 0.5 * (d_ij + d_ik - d_jk)
62+
63+
def delta_hyperbolicity(dismat: Float[torch.Tensor, "n_points n_points"], relative=True, device='cpu', full=False) -> Float[torch.Tensor, ""]:
64+
"""
65+
Compute the delta-hyperbolicity of a metric space.
66+
67+
Args:
68+
dismat: Distance matrix of the metric space.
69+
relative: Whether to return the relative delta-hyperbolicity.
70+
device: Device to run the computation on.
71+
full: Whether to return the full delta tensor or just the maximum delta.
72+
73+
Returns:
74+
delta: Delta-hyperbolicity of the metric space.
75+
"""
76+
77+
n = dismat.shape[0]
78+
p = 0
79+
80+
row = dismat[p, :].unsqueeze(0) # (1,N)
81+
col = dismat[:, p].unsqueeze(1) # (N,1)
82+
XY_p = 0.5 * (row + col - dismat)
83+
84+
XY_p_xy = XY_p.unsqueeze(2).expand(-1, -1, n) # (n,n,n)
85+
XY_p_yz = XY_p.unsqueeze(0).expand(n, -1, -1) # (n,n,n)
86+
XY_p_xz = XY_p.unsqueeze(1).expand(-1, n, -1) # (n,n,n)
87+
88+
out = torch.minimum(XY_p_xy, XY_p_yz)
89+
90+
if not full:
91+
delta = (out - XY_p_xz).max().item()
92+
else:
93+
delta = out - XY_p_xz
94+
95+
if relative:
96+
diam = torch.max(dismat).item()
97+
delta = 2 * delta / diam
98+
99+
return delta
100+
101+

0 commit comments

Comments
 (0)