|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
14 | | -from typing import TYPE_CHECKING |
| 14 | +from typing import TYPE_CHECKING, Any |
15 | 15 |
|
16 | 16 | import torch |
17 | 17 |
|
@@ -99,8 +99,70 @@ def vectorized_delta_hyperbolicity( |
99 | 99 |
|
100 | 100 | out = torch.minimum(XY_w_xy, XY_w_yz) |
101 | 101 |
|
102 | | - delta = out - XY_w_xz if full else (out - XY_w_xz).max().item() |
103 | | - |
104 | | - delta = 2 * delta / torch.max(D) if relative else delta |
| 102 | + if full: |
| 103 | + delta = out - XY_w_xz |
| 104 | + if relative: |
| 105 | + max_dist = torch.max(D) |
| 106 | + if max_dist > 0: |
| 107 | + delta = 2 * delta / max_dist |
| 108 | + else: |
| 109 | + delta = torch.zeros_like(delta) |
| 110 | + else: |
| 111 | + delta = (out - XY_w_xz).max().item() |
| 112 | + if relative: |
| 113 | + max_dist = torch.max(D).item() |
| 114 | + if max_dist > 0: |
| 115 | + delta = 2 * delta / max_dist |
| 116 | + else: |
| 117 | + delta = 0.0 |
105 | 118 |
|
106 | 119 | return delta |
| 120 | + |
| 121 | + |
| 122 | +def delta_hyperbolicity( |
| 123 | + distance_matrix: Float[torch.Tensor, "n_points n_points"], |
| 124 | + method: str = "global", |
| 125 | + **kwargs: Any |
| 126 | +) -> Float[torch.Tensor, "n_points"] | float: |
| 127 | + r"""Computes the δ-hyperbolicity from a distance matrix. |
| 128 | +
|
| 129 | + This function implements δ-hyperbolicity computation, which measures how close a metric |
| 130 | + space is to a tree. The value δ ≥ 0 is a global property; smaller values indicate |
| 131 | + the space is more hyperbolic (tree-like). |
| 132 | +
|
| 133 | + For each triplet of points (x,y,z) and reference point w, computes: |
| 134 | + δ(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w |
| 135 | +
|
| 136 | + where (a,b)_w = ½(d(w,a) + d(w,b) - d(a,b)) is the Gromov product. |
| 137 | +
|
| 138 | + Args: |
| 139 | + distance_matrix: Pairwise distance matrix as a torch.Tensor. |
| 140 | + method: Computation method. Options: |
| 141 | + - "sampled": Random sampling approach, returns array of δ values for sampled triplets |
| 142 | + - "global": Global maximum δ value over all triplets, returns single scalar |
| 143 | + - "full": Full δ tensor over all triplets, returns tensor of shape (n,n,n) |
| 144 | + **kwargs: Additional arguments passed to the computation function. |
| 145 | + For "sampled": n_samples, reference_idx, relative |
| 146 | + For "global"/"full": reference_idx, relative |
| 147 | +
|
| 148 | + Returns: |
| 149 | + delta_values: δ-hyperbolicity estimates. |
| 150 | + - "sampled": torch.Tensor of shape (n_samples,) |
| 151 | + - "global": float scalar (maximum δ value) |
| 152 | + - "full": torch.Tensor of shape (n_points, n_points, n_points) |
| 153 | + """ |
| 154 | + # Validate input |
| 155 | + if not isinstance(distance_matrix, torch.Tensor): |
| 156 | + raise TypeError(f"distance_matrix must be a torch.Tensor, got {type(distance_matrix)}") |
| 157 | + |
| 158 | + D = distance_matrix.float() |
| 159 | + |
| 160 | + if method == "sampled": |
| 161 | + deltas, _ = sampled_delta_hyperbolicity(D, **kwargs) |
| 162 | + return deltas |
| 163 | + elif method == "global": |
| 164 | + return vectorized_delta_hyperbolicity(D, full=False, **kwargs) |
| 165 | + elif method == "full": |
| 166 | + return vectorized_delta_hyperbolicity(D, full=True, **kwargs) |
| 167 | + else: |
| 168 | + raise ValueError(f"Unknown method: {method}. Choose 'sampled', 'global', 'full'") |
0 commit comments