Skip to content
4 changes: 2 additions & 2 deletions manify/curvature_estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
* `sectional_curvature`: Estimates the sectional curvature of a graph from its distance matrix.
"""

from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, vectorized_delta_hyperbolicity
from manify.curvature_estimation.delta_hyperbolicity import delta_hyperbolicity, sampled_delta_hyperbolicity
from manify.curvature_estimation.greedy_method import greedy_signature_selection
from manify.curvature_estimation.sectional_curvature import sectional_curvature

__all__ = [
"greedy_signature_selection",
"sectional_curvature",
"delta_hyperbolicity",
"sampled_delta_hyperbolicity",
"vectorized_delta_hyperbolicity",
]
70 changes: 66 additions & 4 deletions manify/curvature_estimation/delta_hyperbolicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch

Expand Down Expand Up @@ -99,8 +99,70 @@ def vectorized_delta_hyperbolicity(

out = torch.minimum(XY_w_xy, XY_w_yz)

delta = out - XY_w_xz if full else (out - XY_w_xz).max().item()

delta = 2 * delta / torch.max(D) if relative else delta
if full:
delta = out - XY_w_xz
if relative:
max_dist = torch.max(D)
if max_dist > 0:
delta = 2 * delta / max_dist
else:
delta = torch.zeros_like(delta)
else:
delta = (out - XY_w_xz).max().item()
if relative:
max_dist = torch.max(D).item()
if max_dist > 0:
delta = 2 * delta / max_dist
else:
delta = 0.0

return delta


def delta_hyperbolicity(
distance_matrix: Float[torch.Tensor, "n_points n_points"],
method: str = "global",
**kwargs: Any
) -> Float[torch.Tensor, "n_points"] | float:
r"""Computes the δ-hyperbolicity from a distance matrix.

This function implements δ-hyperbolicity computation, which measures how close a metric
space is to a tree. The value δ ≥ 0 is a global property; smaller values indicate
the space is more hyperbolic (tree-like).

For each triplet of points (x,y,z) and reference point w, computes:
δ(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w

where (a,b)_w = ½(d(w,a) + d(w,b) - d(a,b)) is the Gromov product.

Args:
distance_matrix: Pairwise distance matrix as a torch.Tensor.
method: Computation method. Options:
- "sampled": Random sampling approach, returns array of δ values for sampled triplets
- "global": Global maximum δ value over all triplets, returns single scalar
- "full": Full δ tensor over all triplets, returns tensor of shape (n,n,n)
**kwargs: Additional arguments passed to the computation function.
For "sampled": n_samples, reference_idx, relative
For "global"/"full": reference_idx, relative

Returns:
delta_values: δ-hyperbolicity estimates.
- "sampled": torch.Tensor of shape (n_samples,)
- "global": float scalar (maximum δ value)
- "full": torch.Tensor of shape (n_points, n_points, n_points)
"""
# Validate input
if not isinstance(distance_matrix, torch.Tensor):
raise TypeError(f"distance_matrix must be a torch.Tensor, got {type(distance_matrix)}")

D = distance_matrix.float()

if method == "sampled":
deltas, _ = sampled_delta_hyperbolicity(D, **kwargs)
return deltas
elif method == "global":
return vectorized_delta_hyperbolicity(D, full=False, **kwargs)
elif method == "full":
return vectorized_delta_hyperbolicity(D, full=True, **kwargs)
else:
raise ValueError(f"Unknown method: {method}. Choose 'sampled', 'global', 'full'")
Loading
Loading