Skip to content

Commit 24f9d93

Browse files
authored
Merge pull request #15 from pchlenski/curve_estimation
Curve estimation
2 parents 7443323 + ea6a735 commit 24f9d93

File tree

5 files changed

+629
-147
lines changed

5 files changed

+629
-147
lines changed

manify/curvature_estimation/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
* `sectional_curvature`: Estimates the sectional curvature of a graph from its distance matrix.
88
"""
99

10-
from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, vectorized_delta_hyperbolicity
10+
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, sampled_delta_hyperbolicity
1111
from manify.curvature_estimation.greedy_method import greedy_signature_selection
1212
from manify.curvature_estimation.sectional_curvature import sectional_curvature
1313

1414
__all__ = [
1515
"greedy_signature_selection",
1616
"sectional_curvature",
17+
"delta_hyperbolicity",
1718
"sampled_delta_hyperbolicity",
18-
"vectorized_delta_hyperbolicity",
1919
]

manify/curvature_estimation/delta_hyperbolicity.py

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from __future__ import annotations
1313

14-
from typing import TYPE_CHECKING
14+
from typing import TYPE_CHECKING, Any
1515

1616
import torch
1717

@@ -99,8 +99,70 @@ def vectorized_delta_hyperbolicity(
9999

100100
out = torch.minimum(XY_w_xy, XY_w_yz)
101101

102-
delta = out - XY_w_xz if full else (out - XY_w_xz).max().item()
103-
104-
delta = 2 * delta / torch.max(D) if relative else delta
102+
if full:
103+
delta = out - XY_w_xz
104+
if relative:
105+
max_dist = torch.max(D)
106+
if max_dist > 0:
107+
delta = 2 * delta / max_dist
108+
else:
109+
delta = torch.zeros_like(delta)
110+
else:
111+
delta = (out - XY_w_xz).max().item()
112+
if relative:
113+
max_dist = torch.max(D).item()
114+
if max_dist > 0:
115+
delta = 2 * delta / max_dist
116+
else:
117+
delta = 0.0
105118

106119
return delta
120+
121+
122+
def delta_hyperbolicity(
123+
distance_matrix: Float[torch.Tensor, "n_points n_points"],
124+
method: str = "global",
125+
**kwargs: Any
126+
) -> Float[torch.Tensor, "n_points"] | float:
127+
r"""Computes the δ-hyperbolicity from a distance matrix.
128+
129+
This function implements δ-hyperbolicity computation, which measures how close a metric
130+
space is to a tree. The value δ ≥ 0 is a global property; smaller values indicate
131+
the space is more hyperbolic (tree-like).
132+
133+
For each triplet of points (x,y,z) and reference point w, computes:
134+
δ(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w
135+
136+
where (a,b)_w = ½(d(w,a) + d(w,b) - d(a,b)) is the Gromov product.
137+
138+
Args:
139+
distance_matrix: Pairwise distance matrix as a torch.Tensor.
140+
method: Computation method. Options:
141+
- "sampled": Random sampling approach, returns array of δ values for sampled triplets
142+
- "global": Global maximum δ value over all triplets, returns single scalar
143+
- "full": Full δ tensor over all triplets, returns tensor of shape (n,n,n)
144+
**kwargs: Additional arguments passed to the computation function.
145+
For "sampled": n_samples, reference_idx, relative
146+
For "global"/"full": reference_idx, relative
147+
148+
Returns:
149+
delta_values: δ-hyperbolicity estimates.
150+
- "sampled": torch.Tensor of shape (n_samples,)
151+
- "global": float scalar (maximum δ value)
152+
- "full": torch.Tensor of shape (n_points, n_points, n_points)
153+
"""
154+
# Validate input
155+
if not isinstance(distance_matrix, torch.Tensor):
156+
raise TypeError(f"distance_matrix must be a torch.Tensor, got {type(distance_matrix)}")
157+
158+
D = distance_matrix.float()
159+
160+
if method == "sampled":
161+
deltas, _ = sampled_delta_hyperbolicity(D, **kwargs)
162+
return deltas
163+
elif method == "global":
164+
return vectorized_delta_hyperbolicity(D, full=False, **kwargs)
165+
elif method == "full":
166+
return vectorized_delta_hyperbolicity(D, full=True, **kwargs)
167+
else:
168+
raise ValueError(f"Unknown method: {method}. Choose 'sampled', 'global', 'full'")

0 commit comments

Comments
 (0)