Skip to content

Commit 9886838

Browse files
committed
Change __init__ structure; fix embedder tests (hopefully)
1 parent ffc71aa commit 9886838

File tree

9 files changed

+136
-25
lines changed

9 files changed

+136
-25
lines changed

manify/__init__.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,42 @@
11
"""Manify: A Python Library for Learning Non-Euclidean Representations."""
22

3-
import manify.curvature_estimation
4-
import manify.embedders
5-
import manify.manifolds
6-
import manify.predictors
7-
import manify.utils
3+
from manify.curvature_estimation import (
4+
sampled_delta_hyperbolicity,
5+
delta_hyperbolicity,
6+
sectional_curvature,
7+
greedy_signature_selection,
8+
)
9+
from manify.embedders import CoordinateLearning, ProductSpaceVAE, SiameseNetwork
10+
from manify.manifolds import Manifold, ProductManifold
11+
from manify.predictors import ProductSpaceDT, ProductSpaceRF, KappaGCN, ProductSpacePerceptron, ProductSpaceSVM
12+
13+
# import manify.utils
814

915
# Define version and other package metadata
1016
__version__ = "0.0.2"
1117
__author__ = "Philippe Chlenski"
1218
__email__ = "pac@cs.columbia.edu"
1319
__license__ = "MIT"
20+
21+
# Export modules
22+
__all__ = [
23+
# manify.manifolds
24+
"Manifold",
25+
"ProductManifold",
26+
# manify.embedders
27+
"CoordinateLearning",
28+
"ProductSpaceVAE",
29+
"SiameseNetwork",
30+
# manify.predictors
31+
"ProductSpaceDT",
32+
"ProductSpaceRF",
33+
"KappaGCN",
34+
"ProductSpacePerceptron",
35+
"ProductSpaceSVM",
36+
# manify.curvature_estimation
37+
"delta_hyperbolicity",
38+
"sampled_delta_hyperbolicity",
39+
"sectional_curvature",
40+
"greedy_signature_selection",
41+
# no utils
42+
]

manify/curvature_estimation/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
* `sectional_curvature`: Estimates the sectional curvature of a graph from its distance matrix.
88
"""
99

10-
import manify.curvature_estimation.delta_hyperbolicity
11-
import manify.curvature_estimation.greedy_method
12-
import manify.curvature_estimation.sectional_curvature
10+
from manify.curvature_estimation.delta_hyperbolicity import sampled_delta_hyperbolicity, delta_hyperbolicity
11+
from manify.curvature_estimation.greedy_method import greedy_signature_selection
12+
from manify.curvature_estimation.sectional_curvature import sectional_curvature
13+
14+
__all__ = ["greedy_signature_selection", "sectional_curvature", "sampled_delta_hyperbolicity", "delta_hyperbolicity"]

manify/curvature_estimation/greedy_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..manifolds import ProductManifold
1414

1515

16-
def greedy_curvature_method(
16+
def greedy_signature_selection(
1717
pm: ProductManifold,
1818
dists: torch.Tensor,
1919
candidate_components: Tuple[Tuple[float, int], ...] = ((-1.0, 2), (0.0, 2), (1.0, 2)),

manify/embedders/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
* `siamese`: Siamese network-based embedding for metric learning.
88
* `vae`: Variational autoencoders for learning representations in product manifolds.
99
* `_losses`: Loss functions for measuring embedding quality.
10+
* `_base`: Base class for embedders.
1011
"""
1112

12-
import manify.embedders.coordinate_learning
13-
import manify.embedders.siamese
14-
import manify.embedders.vae
13+
from manify.embedders.coordinate_learning import CoordinateLearning
14+
from manify.embedders.siamese import SiameseNetwork
15+
from manify.embedders.vae import ProductSpaceVAE
16+
17+
__all__ = ["CoordinateLearning", "SiameseNetwork", "ProductSpaceVAE"]

manify/embedders/vae.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ def fit( # type: ignore[override]
262262
self,
263263
X: Float[torch.Tensor, "n_points n_features"],
264264
D: None = None,
265-
lr: float = 1e-2,
266-
burn_in_lr: float = 1e-3,
265+
lr: float = 1e-3,
266+
burn_in_lr: float = 1e-4,
267267
curvature_lr: float = 0.0, # Off by default
268268
burn_in_iterations: int = 1,
269269
training_iterations: int = 9,
@@ -303,7 +303,7 @@ def fit( # type: ignore[override]
303303

304304
my_tqdm = tqdm(total=(burn_in_iterations + training_iterations) * len(X))
305305
opt = torch.optim.Adam(
306-
[{"params": self.parameters(), "lr": lr * 0.1}, {"params": self.pm.parameters(), "lr": curvature_lr}]
306+
[{"params": self.parameters(), "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
307307
)
308308
losses: Dict[str, List[float]] = {"elbo": [], "ll": [], "kl": []}
309309
for epoch in range(burn_in_iterations + training_iterations):
@@ -352,13 +352,14 @@ def fit( # type: ignore[override]
352352
return self
353353

354354
def transform(
355-
self, X: Float["n_points n_features"], D: None = None, batch_size: int = 32
355+
self, X: Float["n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
356356
) -> Float["n_points embedding_dim"]:
357357
"""Transform data using the trained VAE. Outputs means of the variational distribution.
358358
359359
Args:
360360
X: Features to embed with VAE.
361361
D: Ignored.
362+
expmap: Whether to use exponential map for embedding.
362363
363364
Returns:
364365
embeddings: Learned embeddings.
@@ -372,7 +373,12 @@ def transform(
372373
embeddings_list = []
373374
for i in range(0, len(X), batch_size):
374375
x_batch = X[i : i + batch_size]
375-
z_mean, _ = self.encode(x_batch)
376+
z_mean_tangent, _ = self.encode(x_batch)
377+
if expmap:
378+
z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix # Adds zeros in the right places
379+
z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
380+
else:
381+
z_mean = z_mean_tangent
376382
embeddings_list.append(z_mean.detach().cpu())
377383

378384
embeddings = torch.cat(embeddings_list, dim=0)

manify/predictors/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,16 @@
181181
182182
where $\varepsilon > 0$ and $\xi_i \geq 0$.
183183
"""
184-
import manify.predictors.decision_tree
185-
import manify.predictors.kappa_gcn
186-
import manify.predictors.perceptron
187-
import manify.predictors.svm
184+
185+
from manify.predictors.decision_tree import ProductSpaceDT, ProductSpaceRF
186+
from manify.predictors.kappa_gcn import KappaGCN
187+
from manify.predictors.perceptron import ProductSpacePerceptron
188+
from manify.predictors.svm import ProductSpaceSVM
189+
190+
__all__ = [
191+
"ProductSpaceDT",
192+
"ProductSpaceRF",
193+
"KappaGCN",
194+
"ProductSpacePerceptron",
195+
"ProductSpaceSVM",
196+
]

manify/utils/benchmarks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from sklearn.base import BaseEstimator
1212
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
1313
from sklearn.linear_model import SGDClassifier, SGDRegressor
14-
from sklearn.metrics import (accuracy_score, f1_score, mean_squared_error,
15-
root_mean_squared_error)
14+
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, root_mean_squared_error
1615
from sklearn.model_selection import train_test_split
1716
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
1817
from sklearn.svm import SVC, SVR

0 commit comments

Comments
 (0)