Skip to content

Commit b31df35

Browse files
committed
Add embedders documentation
1 parent bfb8007 commit b31df35

File tree

6 files changed

+352
-131
lines changed

6 files changed

+352
-131
lines changed

manify/embedders/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
"""Tools for embedding data into Riemannian manifolds and product spaces.
2+
3+
The embedders module provides various ways to embed data into manifolds of constant
4+
or mixed curvature. The module includes:
5+
6+
* `coordinate_learning`: Direct optimization of coordinates in a product manifold.
7+
* `siamese`: Siamese network-based embedding for metric learning.
8+
* `vae`: Variational autoencoders for learning representations in product manifolds.
9+
* `_losses`: Loss functions for measuring embedding quality.
10+
"""
11+
112
import manify.embedders.coordinate_learning
213
import manify.embedders.siamese
314
import manify.embedders.vae

manify/embedders/_losses.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Implementation of different measurement metrics"""
1+
"""Implementation of metrics and loss functions for evaluating embedding quality.
2+
3+
This module provides various functions to measure the quality of embeddings
4+
in Riemannian manifolds, including distortion metrics, average distance error,
5+
and other evaluation measures for both graph and general embedding tasks.
6+
"""
27

38
from __future__ import annotations
49

@@ -16,19 +21,26 @@ def distortion_loss(
1621
D_true: Float[torch.Tensor, "n_points n_points"],
1722
pairwise: bool = False,
1823
) -> Float[torch.Tensor, ""]:
19-
"""Compute the distortion loss between estimated SQUARED distances and true SQUARED distances.
24+
r"""Computes the distortion loss between estimated and true squared distances.
25+
26+
The distortion loss measures how well the pairwise distances in the embedding space match the true distances. It is
27+
calculated as
28+
29+
$$\sum_{i,j} \left(\left(\frac{D_{\text{est}}(i,j)}{D_{\text{true}}(i,j)}\right)^2 - 1\right),$$
30+
31+
where the sum is over all pairs of points (or just unique pairs if `pairwise=True`).
2032
2133
Args:
22-
D_est: A tensor of estimated pairwise distances.
23-
D_true: A tensor of true pairwise distances.
24-
pairwise: A boolean indicating whether to return whether D_est and D_true are pairwise
34+
D_est: Tensor of estimated pairwise squared distances.
35+
D_true: Tensor of true pairwise squared distances.
36+
pairwise: Whether to consider only unique pairs (upper triangular part of the matrices). Defaults to False.
2537
2638
Returns:
27-
float: A float indicating the distortion loss, calculated as the sum of the squared relative
28-
errors between the estimated and true squared distances.
39+
loss: Scalar tensor representing the distortion loss.
2940
30-
See also: `square_loss` in HazyResearch hyperbolics repo:
31-
https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
41+
Note:
42+
This is similar to the `square_loss` in HazyResearch hyperbolics repository:
43+
https://github.com/HazyResearch/hyperbolics/blob/master/pytorch/hyperbolic_models.py#L178
3244
"""
3345

3446
# Turn into flat vectors of pairwise distances. For pairwise distances, we only consider the upper triangle.
@@ -54,15 +66,22 @@ def d_avg(
5466
D_true: Float[torch.Tensor, "n_points n_points"],
5567
pairwise: bool = False,
5668
) -> Float[torch.Tensor, ""]:
57-
"""Average distance error D_av
69+
r"""Computes the average relative distance error (D_avg).
70+
71+
The average distance error is the mean relative error between the estimated and true distances:
72+
73+
$$D_{\text{avg}} = \frac{1}{N} \sum_{i,j} \frac{|D_{\text{est}}(i,j) - D_{\text{true}}(i,j)|}{D_{\text{true}}(i,j)},$$
74+
75+
where $N$ is the number of distances being considered. This metric provides a normalized measure of how accurately
76+
the embedding preserves the original distances.
77+
5878
Args:
59-
D_est (n_points, n_points): A tensor of estimated pairwise distances.
60-
D_true (n_points, n_points).: A tensor of true pairwise distances.
61-
pairwise (bool): A boolean indicating whether to return whether D_est and D_true are pairwise
79+
D_est: Tensor of estimated pairwise distances.
80+
D_true: Tensor of true pairwise distances.
81+
pairwise: Whether to consider only unique pairs (upper triangular part of the matrices). Defaults to False.
6282
6383
Returns:
64-
float: A float indicating the average distance error D_avg, calculated as the
65-
mean relative error across all pairwise distances.
84+
d_avg: Scalar tensor representing the average relative distance error.
6685
"""
6786

6887
if pairwise:
@@ -84,22 +103,41 @@ def d_avg(
84103

85104

86105
def mean_average_precision(x_embed: Float[torch.Tensor, "n_points n_dim"], graph: nx.Graph) -> Float[torch.Tensor, ""]:
87-
"""Mean averae precision (mAP) from the Gu et al paper."""
106+
r"""Computes the mean average precision (mAP) for graph embedding evaluation.
107+
108+
This metric is used to evaluate how well an embedding preserves the neighborhood structure of a graph, as described
109+
in Gu et al. (2019): "Learning Mixed-Curvature Representations in Product Spaces".
110+
111+
Args:
112+
x_embed: Tensor containing the embeddings of the graph nodes.
113+
graph: NetworkX graph representing the original graph structure.
114+
115+
Returns:
116+
mAP: Mean average precision score.
117+
118+
Note:
119+
This function is currently not implemented.
120+
"""
88121
raise NotImplementedError
89122

90123

91124
def dist_component_by_manifold(pm: ProductManifold, x_embed: Float[torch.Tensor, "n_points n_dim"]) -> List[float]:
92-
"""
93-
Compute the variance in pairwise distances explained by each manifold component.
125+
r"""Computes the proportion of variance in pairwise distances explained by each manifold component.
126+
127+
The contribution is calculated as the ratio of the sum of squared distances in each component to the total squared
128+
distance:
129+
130+
$$\text{contribution}_k = \frac{\sum_{i<j} D^2_k(x_i, x_j)}{\sum_{i<j} D^2_{\text{total}}(x_i, x_j)}$$
131+
132+
where $D^2_k$ is the squared distance in the $k$-th manifold component.
94133
95134
Args:
96-
pm: The product manifold.
97-
x_embed (n_points, n_dim): A tensor of embeddings.
135+
pm: The product manifold containing multiple component manifolds.
136+
x_embed: Tensor of embeddings in the product manifold.
98137
99138
Returns:
100-
List[float]: A list of proportions, where each value represents the fraction
101-
of total distance variance explained by the corresponding
102-
manifold component.
139+
contributions: List of proportions, where each value represents the fraction of total distance variance
140+
explained by the corresponding manifold component.
103141
"""
104142
sq_dists_by_manifold = [M.pdist2(x_embed[:, pm.man2dim[i]]) for i, M in enumerate(pm.P)]
105143
total_sq_dist = pm.pdist2(x_embed)

manify/embedders/coordinate_learning.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
"""Implementation for coordinate training and optimization"""
1+
"""Implementation for direct coordinate optimization in Riemannian manifolds.
2+
3+
This module provides functions for learning optimal embeddings in product manifolds by directly optimizing the
4+
coordinates using Riemannian optimization. This approach is particularly useful for embedding graphs using metric learning
5+
to maintain pairwise distances in the target space. The optimization is performed using Riemannian gradient descent
6+
with support for non-transductive training, in which gradients from the test set to the training set are masked out.
7+
"""
28

39
from __future__ import annotations
410

@@ -33,25 +39,44 @@ def train_coords(
3339
loss_window_size: int = 100,
3440
logging_interval: int = 10,
3541
) -> Tuple[Float[torch.Tensor, "n_points n_dim"], Dict[str, List[float]]]:
36-
"""
37-
Coordinate training and optimization
42+
r"""Trains point coordinates in a product manifold to match target distances.
43+
44+
This function optimizes the coordinates of points in a product manifold to match a given distance matrix. The
45+
optimization is performed in two phases:
46+
47+
1. Burn-in phase: Initial optimization with a smaller learning rate to find a good starting configuration.
48+
2. Training phase: Fine-tuning of the coordinates with a larger learning rate, and optionally optimizing the scale
49+
factors (curvatures) of the manifold components.
50+
51+
The optimization uses Riemannian Adam optimizer to respect the manifold structure during gradient updates. The loss
52+
is computed based on the distortion between the pairwise distances in the embedding and the target distances.
53+
54+
For non-transductive settings, the function supports split between training and testing points, optimizing different
55+
combinations of distances (train-train, test-test, train-test).
3856
3957
Args:
40-
pm: ProductManifold object that encapsulates the manifold and its signature.
41-
dists: (n_points, n_points) Tensor representing the pairwise distance matrix between points.
42-
test_indices: (n_test) Tensor representing the indices of the test points.
43-
device: Device for training (default: "cpu").
44-
burn_in_learning_rate: Learning rate during the burn-in phase (default: 1e-3).
45-
burn_in_iterations: Number of iterations during the burn-in phase (default: 2,000).
46-
learning_rate: Learning rate during the training phase (default: 1e-2).
47-
scale_factor_learning_rate: Learning rate for scale factor optimization (default: 0.0).
48-
training_iterations: Number of iterations for the training phase (default: 18,000).
49-
loss_window_size: Window size for computing the moving average of the loss (default: 100).
50-
logging_interval: Interval for logging the training progress (default: 10).
58+
pm: ProductManifold object specifying the target manifold structure.
59+
dists: Tensor representing the target pairwise distance matrix between points.
60+
test_indices: Tensor containing indices of test points for transductive learning.
61+
Defaults to an empty tensor (all points are used for training).
62+
device: Device for tensor computations. Defaults to "cpu".
63+
burn_in_learning_rate: Learning rate for the burn-in phase. Defaults to 1e-3.
64+
burn_in_iterations: Number of iterations for the burn-in phase. Defaults to 2,000.
65+
learning_rate: Learning rate for the main training phase. Defaults to 1e-2.
66+
scale_factor_learning_rate: Learning rate for optimizing manifold scale factors.
67+
Defaults to 0.0 (no optimization of curvatures).
68+
training_iterations: Number of iterations for the main training phase. Defaults to 18,000.
69+
loss_window_size: Window size for computing moving average loss. Defaults to 100.
70+
logging_interval: Interval for logging training progress. Defaults to 10.
5171
5272
Returns:
53-
pm.x_embed: Tensor of the final learned coordinates in the manifold.
54-
losses: List of loss values at each iteration during training.
73+
embeddings: Tensor of shape (n_points, n_dim) with optimized coordinates in the manifold.
74+
losses: Dictionary containing loss histories for different components:
75+
76+
* 'train_train': Loss between training points
77+
* 'test_test': Loss between test points (if test_indices is provided)
78+
* 'train_test': Loss between training and test points (if test_indices is provided)
79+
* 'total': Sum of all loss components
5580
"""
5681
# Move everything to the device
5782
n = dists.shape[0]

manify/embedders/siamese.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
"""Siamese network embedder"""
1+
"""Siamese network implementation for manifold embedding.
2+
3+
This module provides a Siamese network architecture that can be used for embedding data into product manifolds. Siamese
4+
networks are particularly useful for metric learning tasks, where the goal is to learn a distance-preserving embedding,
5+
while also encoding a set of features.
6+
7+
The SiameseNetwork class supports both encoding (embedding) data into a manifold space and optionally decoding
8+
(reconstructing) from the embedding space back to the original data space.
9+
"""
210

311
from __future__ import annotations
412

@@ -11,6 +19,29 @@
1119

1220

1321
class SiameseNetwork(torch.nn.Module):
22+
"""Siamese network for embedding data into a product manifold space.
23+
24+
A Siamese network consists of an encoder network that maps input data to a latent representation in a product
25+
manifold, and optionally a decoder network that maps the latent representation back to the original feature space.
26+
27+
Attributes:
28+
pm: The product manifold object defining the embedding space.
29+
encoder: Neural network module that maps input data to the embedding space.
30+
decoder: Optional neural network module for reconstructing input data from embeddings.
31+
reconstruction_loss: Loss function for measuring reconstruction quality.
32+
33+
Args:
34+
pm: Product manifold object defining the target embedding space.
35+
encoder: Neural network module that maps inputs to the embedding space.
36+
decoder: Optional neural network module that maps embeddings back to input space.
37+
If None, a no-op identity module is used. Defaults to None.
38+
reconstruction_loss: Type of reconstruction loss to use.
39+
Currently only "mse" (mean squared error) is supported. Defaults to "mse".
40+
41+
Raises:
42+
ValueError: If an unsupported reconstruction_loss is specified.
43+
"""
44+
1445
def __init__(
1546
self,
1647
pm: ProductManifold,
@@ -35,23 +66,28 @@ def __init__(
3566
raise ValueError(f"Unknown reconstruction loss: {reconstruction_loss}")
3667

3768
def encode(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Float[torch.Tensor, "batch_size n_latent"]:
38-
"""Encodes the input tensor into a latent representation.
69+
"""Encodes input data into the manifold embedding space.
70+
71+
Takes a batch of input data and passes it through the encoder network to obtain embeddings in the manifold.
3972
4073
Args:
41-
x (TensorType["batch_size", "n_features"]): The input tensor.
74+
x: Input data tensor.
4275
4376
Returns:
44-
TensorType["batch_size", "n_latent"]: The encoded latent representation.
77+
embeddings: Tensor containing the embeddings in the manifold space.
4578
"""
4679
return self.encoder(x)
4780

4881
def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.Tensor, "batch_size n_features"]:
49-
"""Decodes the latent representation back to the input space.
82+
"""Decodes manifold embeddings back to the original input space.
83+
84+
Takes a batch of embeddings from the manifold space and passes them through
85+
the decoder network to reconstruct the original input data.
5086
5187
Args:
52-
z (TensorType["batch_size", "n_latent"]): The latent representation.
88+
z: Embedding tensor from the manifold space.
5389
5490
Returns:
55-
TensorType["batch_size", "n_features"]: The reconstructed input tensor.
91+
reconstructed: Tensor containing the reconstructed input data.
5692
"""
5793
return self.decoder(z)

0 commit comments

Comments
 (0)