# Commit 8355612 — Add greedy signature selection

Parent: 4579e2c. 4 files changed: +223, −35 lines.

## README.md (27 additions, 22 deletions)
````diff
@@ -1,12 +1,13 @@
 # Manify 🪐
-> A Python Library for Learning Non-Euclidean Representations
 
 [![Python Version](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/downloads/)
-[![License](https://img.shields.io/github/license/pchlenski/manify)](https://github.com/pchlenski/manify/blob/main/LICENSE)
 [![PyPI version](https://badge.fury.io/py/manify.svg)](https://badge.fury.io/py/manify)
 [![Tests](https://github.com/pchlenski/manify/actions/workflows/test.yml/badge.svg)](https://github.com/pchlenski/manify/actions/workflows/test.yml)
 [![codecov](https://codecov.io/gh/pchlenski/manify/branch/main/graph/badge.svg)](https://codecov.io/gh/pchlenski/manify)
 [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
+[![Documentation](https://img.shields.io/badge/docs-manify.readthedocs.io-blue)](https://manify.readthedocs.io)
+[![arXiv](https://img.shields.io/badge/arXiv-2503.09576-b31b1b.svg)](https://arxiv.org/abs/2503.09576)
+[![License](https://img.shields.io/github/license/pchlenski/manify)](https://github.com/pchlenski/manify/blob/main/LICENSE)
 
 Manify is a Python library for non-Euclidean representation learning.
 It is built on top of `geoopt` and follows `scikit-learn` API conventions.
@@ -18,12 +19,6 @@ The library supports a variety of workflows involving (products of) Riemannian m
 perceptrons, and neural networks.
 - Clustering manifold-valued data using Riemannian fuzzy K-Means
 
-📖 **Documentation**: [manify.readthedocs.io](https://manify.readthedocs.io)
-
-📝 **Manuscript**: [Manify: A Python Library for Learning Non-Euclidean Representations](https://arxiv.org/abs/2503.09576)
-
-🐛 **Issue Tracker**: [Github](https://github.com/pchlenski/manify/issues)
-
 ## Installation
 
 There are two ways to install `manify`:
@@ -41,29 +36,25 @@ There are two ways to install `manify`:
 ## Quick Example
 
 ```python
-import torch
-from manify.manifolds import ProductManifold
-from manify.embedders import CoordinateLearning
-from manify.predictors.decision_tree import ProductSpaceDT
+import manify
 from manify.utils.dataloaders import load_hf
 from sklearn.model_selection import train_test_split
 
-# Load graph data
+# Load Polblogs graph from HuggingFace
 features, dists, adj, labels = load_hf("polblogs")
 
-# Create product manifold
-pm = ProductManifold(signature=[(1, 4)])  # S^4_1
+# Create an S^4 x H^4 product manifold
+pm = manify.ProductManifold(signature=[(1.0, 4), (-1.0, 4)])
 
 # Learn embeddings (Gu et al (2018) method)
-embedder = CoordinateLearning(pm=pm)
-embedder.fit(X=None, D=dists)
-X_embedded = embedder.transform()
+embedder = manify.CoordinateLearning(pm=pm)
+X_embedded = embedder.fit_transform(X=None, D=dists, burn_in_iterations=200, training_iterations=800)
 
 # Train and evaluate classifier (Chlenski et al (2025) method)
 X_train, X_test, y_train, y_test = train_test_split(X_embedded, labels)
-tree = ProductSpaceDT(pm=pm, max_depth=3, task="classification")
-tree.fit(X_train, y_train)
-print(tree.score(X_test, y_test))
+model = manify.ProductSpaceDT(pm=pm, max_depth=3, task="classification")
+model.fit(X_train, y_train)
+print(model.score(X_test, y_test))
 ```
 
 ## Modules
@@ -113,7 +104,7 @@ Decision Trees and Random Forests paper.
 Please read our [contributing guide](https://github.com/pchlenski/manify/blob/main/CONTRIBUTING.md) for details on how
 to contribute to the project.
 
-## Citation
+## References
 If you use our work, please cite the `Manify` paper:
 ```bibtex
 @misc{chlenski2025manifypythonlibrarylearning,
@@ -126,3 +117,17 @@ If you use our work, please cite the `Manify` paper:
 url={https://arxiv.org/abs/2503.09576},
 }
 ```
+
+Additionally, if you use one of the methods implemented in `manify`, please cite the original papers:
+- `CoordinateLearning`: Gu et al. "Learning Mixed-Curvature Representations in Product Spaces." ICLR 2019. [https://openreview.net/forum?id=HJxeWnCcF7](https://openreview.net/forum?id=HJxeWnCcF7)
+- `ProductSpaceVAE`: Skopek et al. "Mixed-Curvature Variational Autoencoders." ICLR 2020. [https://openreview.net/forum?id=S1g6xeSKDS](https://openreview.net/forum?id=S1g6xeSKDS)
+- `SiameseNetwork`: Based on Siamese networks: Chopra et al. "Learning a Similarity Metric Discriminatively, with Application to Face Verification." CVPR 2005. [https://ieeexplore.ieee.org/document/1467314](https://ieeexplore.ieee.org/document/1467314)
+- `ProductSpaceDT` and `ProductSpaceRF`: Chlenski et al. "Mixed Curvature Decision Trees and Random Forests." ICML 2025. [https://arxiv.org/abs/2410.13879](https://arxiv.org/abs/2410.13879)
+- `KappaGCN`: Bachmann et al. "Constant Curvature Graph Convolutional Networks." ICML 2020. [https://proceedings.mlr.press/v119/bachmann20a.html](https://proceedings.mlr.press/v119/bachmann20a.html)
+- `ProductSpacePerceptron` and `ProductSpaceSVM`: Tabaghi et al. "Linear Classifiers in Product Space Forms." arXiv 2021. [https://arxiv.org/abs/2102.10204](https://arxiv.org/abs/2102.10204)
+- `RiemannianFuzzyKMeans` and `RiemannianAdan`: Yuan et al. "Riemannian Fuzzy K-Means." OpenReview 2025. [https://openreview.net/forum?id=9VmOgMN4Ie](https://openreview.net/forum?id=9VmOgMN4Ie)
+- Delta-hyperbolicity computation: Based on Gromov's δ-hyperbolicity metric for tree-likeness of metric spaces. Gromov, M. "Hyperbolic Groups." Essays in Group Theory, 1987. [https://link.springer.com/chapter/10.1007/978-1-4613-9586-7_3](https://link.springer.com/chapter/10.1007/978-1-4613-9586-7_3)
+- Sectional curvature estimation: Gu et al. "Learning Mixed-Curvature Representations in Product Spaces." ICLR 2019. [https://openreview.net/forum?id=HJxeWnCcF7](https://openreview.net/forum?id=HJxeWnCcF7)
+- Greedy signature selection: Tabaghi et al. "Linear Classifiers in Product Space Forms." arXiv 2021. [https://arxiv.org/abs/2102.10204](https://arxiv.org/abs/2102.10204)
````
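In the quick example above, a signature is a list of `(curvature, dimension)` pairs, so `[(1.0, 4), (-1.0, 4)]` denotes the product S^4 × H^4. A small helper makes the convention explicit; `describe_signature` is a hypothetical illustration, not part of the `manify` API:

```python
def describe_signature(signature: list[tuple[float, int]]) -> str:
    """Render a product-manifold signature as a human-readable string.

    Positive curvature -> spherical (S), negative -> hyperbolic (H),
    zero -> Euclidean (E), following the (curvature, dimension) pairs
    used in the README example. Illustrative only.
    """
    parts = []
    for curvature, dim in signature:
        if curvature > 0:
            parts.append(f"S^{dim}")
        elif curvature < 0:
            parts.append(f"H^{dim}")
        else:
            parts.append(f"E^{dim}")
    return " x ".join(parts)


print(describe_signature([(1.0, 4), (-1.0, 4)]))  # S^4 x H^4
```

The candidate components used later in greedy selection, `((-1.0, 2), (0.0, 2), (1.0, 2))`, render as `H^2`, `E^2`, and `S^2` under the same convention.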
## manify/curvature_estimation/_pipelines.py (new file, 112 additions)

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Literal

    from jaxtyping import Float

import torch
from sklearn.model_selection import train_test_split

from ..embedders._losses import distortion_loss
from ..embedders.coordinate_learning import CoordinateLearning
from ..manifolds import ProductManifold
from ..predictors._base import BasePredictor
from ..predictors.decision_tree import ProductSpaceDT


def distortion_pipeline(
    pm: ProductManifold,
    dists: Float[torch.Tensor, "n_nodes n_nodes"],
    embedder_init_kwargs: dict[str, Any] | None = None,
    embedder_fit_kwargs: dict[str, Any] | None = None,
) -> float:
    """Distortion-based pipeline for greedy signature selection.

    Embeds `dists` into `pm` via coordinate learning and scores the result
    by its distortion loss.

    Args:
        pm: Product manifold to use for the pipeline.
        dists: Pairwise distances to approximate.
        embedder_init_kwargs: Additional keyword arguments for initializing the embedder model.
        embedder_fit_kwargs: Additional keyword arguments for fitting the embedder model.

    Returns:
        The distortion loss of the learned embeddings (lower is better).
    """
    if embedder_init_kwargs is None:
        embedder_init_kwargs = {}
    if embedder_fit_kwargs is None:
        embedder_fit_kwargs = {}

    dists = dists.to(pm.device)
    dists_rescaled = dists / dists.max()

    # Initialize and fit the embedder model
    model = CoordinateLearning(pm=pm, device=pm.device, **embedder_init_kwargs)
    model.fit(X=None, D=dists_rescaled, **embedder_fit_kwargs)

    # Loss is the distortion loss of the new embeddings
    embeddings = model.embeddings_
    new_dists = pm.pdist(X=embeddings)
    return float(distortion_loss(new_dists, dists_rescaled).item())


def classifier_pipeline(
    pm: ProductManifold,
    dists: Float[torch.Tensor, "n_nodes n_nodes"],
    labels: Float[torch.Tensor, "n_nodes"],
    classifier: type[BasePredictor] = ProductSpaceDT,
    task: Literal["classification", "regression"] = "classification",
    embedder_init_kwargs: dict[str, Any] | None = None,
    embedder_fit_kwargs: dict[str, Any] | None = None,
    model_init_kwargs: dict[str, Any] | None = None,
    model_fit_kwargs: dict[str, Any] | None = None,
) -> float:
    """Classifier-based pipeline for greedy signature selection.

    Args:
        pm: Product manifold to use for the pipeline.
        dists: Pairwise distances to approximate.
        labels: Labels for the nodes, used for training the classifier.
        classifier: Classifier to use for evaluating the signature.
        task: Task type, either "classification" or "regression".
        embedder_init_kwargs: Additional keyword arguments for initializing the coordinate learning model.
        embedder_fit_kwargs: Additional keyword arguments for fitting the coordinate learning model.
        model_init_kwargs: Additional keyword arguments for initializing the classifier.
        model_fit_kwargs: Additional keyword arguments for fitting the classifier.

    Returns:
        The negated test-set accuracy (classification) or the test-set score
        (regression) of the classifier after embedding the distances, so that
        lower is always better.
    """
    if embedder_init_kwargs is None:
        embedder_init_kwargs = {}
    if embedder_fit_kwargs is None:
        embedder_fit_kwargs = {}
    if model_init_kwargs is None:
        model_init_kwargs = {}
    if model_fit_kwargs is None:
        model_fit_kwargs = {}

    dists = dists.to(pm.device)
    dists_rescaled = dists / dists.max()

    # Embedding step
    embedder = CoordinateLearning(pm=pm, device=pm.device, **embedder_init_kwargs)
    embedder.fit(X=None, D=dists_rescaled, **embedder_fit_kwargs)
    X = embedder.embeddings_

    # Train-test split
    X_train, X_test, y_train, y_test = train_test_split(X, labels)

    # Train classifier
    model_init_kwargs["task"] = task
    model = classifier(pm=pm, **model_init_kwargs)
    model.fit(X=X_train, y=y_train, **model_fit_kwargs)
    loss = model.score(X=X_test, y=y_test)

    # For classification we want to maximize accuracy; for regression we minimize
    # MSE, so negate the score only in the classification case.
    return -loss if task == "classification" else loss
```
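The exact form of `manify`'s `distortion_loss` is not shown in this diff. A standard average-distortion objective for metric embeddings is the mean of `|d_est / d_true − 1|` over distinct pairs, which is zero exactly when distances are preserved. A dependency-free sketch under that assumption (`avg_distortion` is an illustrative stand-in, not the library's implementation):

```python
def avg_distortion(est_dists: list[list[float]], true_dists: list[list[float]]) -> float:
    """Average pairwise distortion |d_est / d_true - 1| over distinct pairs.

    Pairs with zero true distance (e.g. the diagonal) are skipped. This is a
    common embedding-quality objective; manify's `distortion_loss` may differ
    in detail.
    """
    n = len(true_dists)
    total, pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            if true_dists[i][j] > 0:
                total += abs(est_dists[i][j] / true_dists[i][j] - 1.0)
                pairs += 1
    return total / pairs if pairs else 0.0


# A perfect embedding has zero distortion:
d = [[0.0, 1.0], [1.0, 0.0]]
print(avg_distortion(d, d))  # 0.0
```

Both pipelines return a scalar where lower is better, which is what lets `greedy_signature_selection` treat them interchangeably.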

## manify/curvature_estimation/greedy_method.py (38 additions, 13 deletions)

```diff
@@ -8,36 +8,61 @@
 
 from typing import TYPE_CHECKING
 
-import torch
-
 if TYPE_CHECKING:
-    from jaxtyping import Float
+    from collections.abc import Callable, Iterable
+    from typing import Any
 
 from ..manifolds import ProductManifold
+from ._pipelines import distortion_pipeline
 
 
 def greedy_signature_selection(
-    pm: ProductManifold,
-    dists: Float[torch.Tensor, "n_points n_points"],
-    candidate_components: tuple[tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
+    candidate_components: Iterable[tuple[float, int]] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
     max_components: int = 3,
-) -> ProductManifold:
+    pipeline: Callable[..., float] = distortion_pipeline,
+    **kwargs: Any,
+) -> tuple[ProductManifold, list[float]]:
     r"""Greedily estimates an optimal product manifold signature.
 
     This implements the greedy signature selection algorithm that incrementally builds a product manifold
     by selecting components that best preserve distances. At each step, it chooses the manifold component
     that maximizes distortion reduction.
 
     Args:
-        pm: Initial product manifold to use as starting point.
-        dists: Pairwise distance matrix to approximate.
         candidate_components: Candidate (curvature, dimension) pairs to consider.
         max_components: Maximum number of components to include.
+        pipeline: Function that takes a ProductManifold, plus additional arguments, and returns a loss value.
+        **kwargs: Additional keyword arguments to pass to the pipeline function.
 
     Returns:
         optimal_pm: Optimized product manifold with the selected signature.
-
-    Note:
-        This function is not yet implemented.
+        loss_history: Pipeline loss after each accepted component.
     """
-    raise NotImplementedError
+    # Initialize variables
+    signature: list[tuple[float, int]] = []
+    loss_history: list[float] = []
+    current_loss = float("inf")
+    candidate_components_list = list(candidate_components)  # For type-safe iteration
+
+    # Greedy loop
+    for _ in range(max_components):
+        best_loss, best_idx = current_loss, -1
+
+        # Try each candidate
+        for idx, comp in enumerate(candidate_components_list):
+            pm = ProductManifold(signature=signature + [comp])
+            loss = pipeline(pm, **kwargs)
+            if loss < best_loss:
+                best_loss, best_idx = loss, idx
+
+        # If no improvement, stop
+        if best_idx < 0:
+            break
+
+        # Otherwise accept that component
+        signature.append(candidate_components_list[best_idx])
+        current_loss = best_loss
+        loss_history.append(current_loss)
+
+    # Return final manifold
+    optimal_pm = ProductManifold(signature=signature)
+    return optimal_pm, loss_history
```
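Because `pipeline` is just a callable from a manifold to a loss, the greedy control flow can be exercised in isolation with a mock pipeline. The sketch below mirrors the loop in `greedy_signature_selection` but operates on raw signatures so it needs no `manify` objects; `greedy_select` and `mock_pipeline` are illustrative names, not library APIs:

```python
from typing import Callable

Signature = list[tuple[float, int]]


def greedy_select(
    candidates: list[tuple[float, int]],
    max_components: int,
    pipeline: Callable[[Signature], float],
) -> tuple[Signature, list[float]]:
    """Greedy component selection mirroring greedy_signature_selection."""
    signature: Signature = []
    history: list[float] = []
    current = float("inf")
    for _ in range(max_components):
        best, best_idx = current, -1
        # Score each candidate appended to the signature accepted so far
        for idx, comp in enumerate(candidates):
            loss = pipeline(signature + [comp])
            if loss < best:
                best, best_idx = loss, idx
        if best_idx < 0:  # no candidate improves the loss: stop early
            break
        signature.append(candidates[best_idx])
        current = best
        history.append(current)
    return signature, history


# Mock pipeline: pretend hyperbolic components help, with diminishing returns
# and a small per-component penalty.
def mock_pipeline(sig: Signature) -> float:
    hyperbolic = sum(1 for curvature, _ in sig if curvature < 0)
    return 1.0 / (1 + hyperbolic) + 0.01 * len(sig)


sig, hist = greedy_select([(-1.0, 2), (0.0, 2), (1.0, 2)], 3, mock_pipeline)
print(sig)  # [(-1.0, 2), (-1.0, 2), (-1.0, 2)]
```

With this mock, each round picks the hyperbolic candidate and the loss history decreases monotonically, matching the early-stopping contract of the real function (stop as soon as no candidate beats the current loss).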

## tests/test_curvature_estimation.py (46 additions, 0 deletions)

```diff
@@ -1,7 +1,10 @@
 import torch
 
+from manify.curvature_estimation._pipelines import classifier_pipeline, distortion_pipeline
 from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, vectorized_delta_hyperbolicity
+from manify.curvature_estimation.greedy_method import greedy_signature_selection
 from manify.manifolds import ProductManifold
+from manify.utils.dataloaders import load_hf
 
 
 def iterative_delta_hyperbolicity(D, reference_idx=0, relative=True):
@@ -67,3 +70,46 @@ def test_delta_hyperbolicity():
     assert torch.allclose(sampled_deltas, vectorized_deltas[indices[:, 0], indices[:, 1], indices[:, 2]], atol=1e-5), (
         "Sampled deltas should be close to vectorized deltas."
     )
+
+
+def test_greedy_method():
+    # Get a very small subset of the polblogs dataset
+    _, D, _, y = load_hf("polblogs")
+    D = D[:128, :128] / D.max()
+    y = y[:128]
+
+    max_components = 3
+    embedder_init_kwargs = {"random_state": 42}
+    embedder_fit_kwargs = {"burn_in_iterations": 10, "training_iterations": 90, "lr": 1e-2}
+
+    # Try distortion pipeline
+    optimal_pm, loss_history = greedy_signature_selection(
+        pipeline=distortion_pipeline,
+        dists=D,
+        embedder_init_kwargs=embedder_init_kwargs,
+        embedder_fit_kwargs=embedder_fit_kwargs,
+    )
+    assert len(optimal_pm.signature) == len(loss_history)
+    assert len(optimal_pm.signature) <= max_components
+    assert len(optimal_pm.signature) > 0, "Optimal signature should not be empty"
+    assert len(loss_history) > 0, "Loss history should not be empty"
+    if len(loss_history) > 1:
+        assert loss_history[-1] < loss_history[0], "Loss should decrease over iterations"
+
+    # Try classifier pipeline
+    optimal_pm, loss_history = greedy_signature_selection(
+        pipeline=classifier_pipeline,
+        labels=y,
+        dists=D,
+        embedder_init_kwargs=embedder_init_kwargs,
+        embedder_fit_kwargs=embedder_fit_kwargs,
+    )
+    assert len(optimal_pm.signature) == len(loss_history)
+    assert len(optimal_pm.signature) <= max_components
+    assert len(optimal_pm.signature) > 0, "Optimal signature should not be empty"
+    assert len(loss_history) > 0, "Loss history should not be empty"
+    if len(loss_history) > 1:
+        assert loss_history[-1] < loss_history[0], "Loss should decrease over iterations"
```
