1- """Implementation of different measurement metrics"""
1+ """Implementation of metrics and loss functions for evaluating embedding quality.
2+
3+ This module provides various functions to measure the quality of embeddings
4+ in Riemannian manifolds, including distortion metrics, average distance error,
5+ and other evaluation measures for both graph and general embedding tasks.
6+ """
27
38from __future__ import annotations
49
@@ -16,19 +21,26 @@ def distortion_loss(
1621 D_true : Float [torch .Tensor , "n_points n_points" ],
1722 pairwise : bool = False ,
1823) -> Float [torch .Tensor , "" ]:
19- """Compute the distortion loss between estimated SQUARED distances and true SQUARED distances.
24+ r"""Computes the distortion loss between estimated and true squared distances.
25+
26+ The distortion loss measures how well the pairwise distances in the embedding space match the true distances. It is
27+ calculated as
28+
29+ $$\sum_{i,j} \left(\left(\frac{D_{\text{est}}(i,j)}{D_{\text{true}}(i,j)}\right)^2 - 1\right),$$
30+
31+ where the sum is over all pairs of points (or just unique pairs if `pairwise=True`).
2032
2133 Args:
22- D_est: A tensor of estimated pairwise distances.
23- D_true: A tensor of true pairwise distances.
24- pairwise: A boolean indicating whether to return whether D_est and D_true are pairwise
34+ D_est: Tensor of estimated pairwise squared distances.
35+ D_true: Tensor of true pairwise squared distances.
36+ pairwise: Whether to consider only unique pairs (upper triangular part of the matrices). Defaults to False.
2537
2638 Returns:
27- float: A float indicating the distortion loss, calculated as the sum of the squared relative
28- errors between the estimated and true squared distances.
39+ loss: Scalar tensor representing the distortion loss.
2940
30- See also: `square_loss` in HazyResearch hyperbolics repo:
31- https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
41+ Note:
42+ This is similar to the `square_loss` in HazyResearch hyperbolics repository:
43+ https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
3244 """
3345
3446 # Turn into flat vectors of pairwise distances. For pairwise distances, we only consider the upper triangle.
@@ -54,15 +66,22 @@ def d_avg(
5466 D_true : Float [torch .Tensor , "n_points n_points" ],
5567 pairwise : bool = False ,
5668) -> Float [torch .Tensor , "" ]:
57- """Average distance error D_av
69+ r"""Computes the average relative distance error (D_avg).
70+
71+ The average distance error is the mean relative error between the estimated and true distances:
72+
73+ $$D_{\text{avg}} = \frac{1}{N} \sum_{i,j} \frac{|D_{\text{est}}(i,j) - D_{\text{true}}(i,j)|}{D_{\text{true}}(i,j)},$$
74+
75+ where $N$ is the number of distances being considered. This metric provides a normalized measure of how accurately
76+ the embedding preserves the original distances.
77+
5878 Args:
59- D_est (n_points, n_points): A tensor of estimated pairwise distances.
60- D_true (n_points, n_points).: A tensor of true pairwise distances.
61- pairwise (bool): A boolean indicating whether to return whether D_est and D_true are pairwise
79+ D_est: Tensor of estimated pairwise distances.
80+ D_true: Tensor of true pairwise distances.
81+ pairwise: Whether to consider only unique pairs (upper triangular part of the matrices). Defaults to False.
6282
6383 Returns:
64- float: A float indicating the average distance error D_avg, calculated as the
65- mean relative error across all pairwise distances.
84+ d_avg: Scalar tensor representing the average relative distance error.
6685 """
6786
6887 if pairwise :
@@ -84,22 +103,41 @@ def d_avg(
84103
85104
86105def mean_average_precision (x_embed : Float [torch .Tensor , "n_points n_dim" ], graph : nx .Graph ) -> Float [torch .Tensor , "" ]:
87- """Mean averae precision (mAP) from the Gu et al paper."""
106+ r"""Computes the mean average precision (mAP) for graph embedding evaluation.
107+
108+ This metric is used to evaluate how well an embedding preserves the neighborhood structure of a graph, as described
109+ in Gu et al. (2019): "Learning Mixed-Curvature Representations in Product Spaces".
110+
111+ Args:
112+ x_embed: Tensor containing the embeddings of the graph nodes.
113+ graph: NetworkX graph representing the original graph structure.
114+
115+ Returns:
116+ mAP: Mean average precision score.
117+
118+ Note:
119+ This function is currently not implemented.
120+ """
88121 raise NotImplementedError
89122
90123
91124def dist_component_by_manifold (pm : ProductManifold , x_embed : Float [torch .Tensor , "n_points n_dim" ]) -> List [float ]:
92- """
93- Compute the variance in pairwise distances explained by each manifold component.
125+ r"""Computes the proportion of variance in pairwise distances explained by each manifold component.
126+
127+ The contribution is calculated as the ratio of the sum of squared distances in each component to the total squared
128+ distance:
129+
130+ $$\text{contribution}_k = \frac{\sum_{i<j} D^2_k(x_i, x_j)}{\sum_{i<j} D^2_{\text{total}}(x_i, x_j)}$$
131+
132+ where $D^2_k$ is the squared distance in the $k$-th manifold component.
94133
95134 Args:
96- pm: The product manifold.
97- x_embed (n_points, n_dim): A tensor of embeddings.
135+ pm: The product manifold containing multiple component manifolds .
136+ x_embed: Tensor of embeddings in the product manifold .
98137
99138 Returns:
100- List[float]: A list of proportions, where each value represents the fraction
101- of total distance variance explained by the corresponding
102- manifold component.
139+ contributions: List of proportions, where each value represents the fraction of total distance variance
140+ explained by the corresponding manifold component.
103141 """
104142 sq_dists_by_manifold = [M .pdist2 (x_embed [:, pm .man2dim [i ]]) for i , M in enumerate (pm .P )]
105143 total_sq_dist = pm .pdist2 (x_embed )
0 commit comments