Skip to content

Commit 6052f76

Browse files
committed
Standardize RFK syntax; add curvature estimation + Kappa MLP + clustering to tutorial
1 parent fd166f6 commit 6052f76

File tree

5 files changed: +934 additions, -115 deletions

manify/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
install_import_hook("manify", "beartype.beartype")
99
print("Beartype import hook installed for Manify. Will use beartype for type checking.")
1010

11+
from manify.clustering import RiemannianFuzzyKMeans
1112
from manify.curvature_estimation import greedy_signature_selection, sampled_delta_hyperbolicity, sectional_curvature
1213
from manify.embedders import CoordinateLearning, ProductSpaceVAE, SiameseNetwork
1314
from manify.manifolds import Manifold, ProductManifold
@@ -41,5 +42,7 @@
4142
"sampled_delta_hyperbolicity",
4243
"sectional_curvature",
4344
"greedy_signature_selection",
45+
# manify.clustering
46+
"RiemannianFuzzyKMeans",
4447
# no utils
4548
]

manify/clustering/fuzzy_kmeans.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
4343
4444
Attributes:
4545
n_clusters: The number of clusters to form.
46-
manifold: An initialized manifold object (from manifolds.py) on which clustering will be performed.
46+
pm: An initialized manifold object (from manifolds.py) on which clustering will be performed.
4747
m: Fuzzifier parameter. Controls the softness of the partition.
4848
lr: Learning rate for the optimizer.
4949
max_iter: Maximum number of iterations for the optimization.
@@ -71,7 +71,7 @@ class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
7171
def __init__(
7272
self,
7373
n_clusters: int,
74-
manifold: Manifold | ProductManifold,
74+
pm: Manifold | ProductManifold,
7575
m: float = 2.0,
7676
lr: float = 0.1,
7777
max_iter: int = 100,
@@ -81,7 +81,7 @@ def __init__(
8181
verbose: bool = False,
8282
):
8383
self.n_clusters = n_clusters
84-
self.manifold = manifold
84+
self.pm = pm
8585
self.m = m
8686
self.lr = lr
8787
self.max_iter = max_iter
@@ -97,11 +97,11 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
9797
torch.manual_seed(self.random_state)
9898
np.random.seed(self.random_state)
9999

100-
# Input data X's second dimension should match the manifold's ambient dimension
101-
if X.shape[1] != self.manifold.ambient_dim:
100+
# Input data X's second dimension should match the pm's ambient dimension
101+
if X.shape[1] != self.pm.ambient_dim:
102102
raise ValueError(
103103
f"Input data X's dimension ({X.shape[1]}) does not match "
104-
f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
104+
f"the manifold's ambient dimension ({self.pm.ambient_dim})."
105105
)
106106

107107
# Generate initial centers using the manifold's sample method
@@ -112,16 +112,15 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
112112

113113
# For sampling initial centers, we want n_clusters distinct points.
114114
# The .sample() method typically takes a z_mean of shape (num_points_to_sample, ambient_dim).
115-
# If we provide self.manifold.mu0 repeated n_clusters times,
115+
# If we provide self.pm.mu0 repeated n_clusters times,
116116
# it samples n_clusters points, each around mu0.
117-
means_for_sampling_centers = self.manifold.mu0.repeat(self.n_clusters, 1)
118-
centers = self.manifold.sample(z_mean=means_for_sampling_centers)
117+
centers = self.pm.sample(self.n_clusters)
119118

120119
# IMPORTANT: Use self.manifold.manifold for ManifoldParameter,
121120
# as self.manifold is our wrapper and self.manifold.manifold is the geoopt object.
122121
self.mu_ = ManifoldParameter(
123122
centers.clone().detach(), # type: ignore
124-
manifold=self.manifold.manifold,
123+
manifold=self.pm.manifold,
125124
) # Ensure centers are detached
126125
self.mu_.requires_grad_(True)
127126

@@ -150,22 +149,22 @@ def fit(self, X: Float[torch.Tensor, "n_points n_features"], y: None = None) ->
150149
X = torch.tensor(X, dtype=torch.get_default_dtype())
151150

152151
# Ensure X is on the same device as the manifold
153-
X = X.to(self.manifold.device)
152+
X = X.to(self.pm.device)
154153

155-
if X.shape[1] != self.manifold.ambient_dim:
154+
if X.shape[1] != self.pm.ambient_dim:
156155
raise ValueError(
157156
f"Input data X's dimension ({X.shape[1]}) in fit() does not match "
158-
f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
157+
f"the manifold's ambient dimension ({self.pm.ambient_dim})."
159158
)
160159

161160
self._init_centers(X)
162161
m, tol = self.m, self.tol
163162
losses = []
164163
for i in range(self.max_iter):
165164
self.opt_.zero_grad()
166-
# self.manifold.dist is implemented in manifolds.py and handles broadcasting
167-
d = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> d is (N,K)
168-
# Original RFK: d = self.manifold.dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
165+
# self.pm.dist is implemented in manifolds.py and handles broadcasting
166+
d = self.pm.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> d is (N,K)
167+
# Original RFK: d = self.pm.dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
169168
# The .dist in manifolds.py uses X[:, None] and Y[None, :], so direct call should work if mu_ is (K,D)
170169

171170
S = torch.sum(d.pow(-2 / (m - 1)) + 1e-8, dim=1) # Add epsilon for stability
@@ -181,7 +180,7 @@ def fit(self, X: Float[torch.Tensor, "n_points n_features"], y: None = None) ->
181180
# save the result
182181
self.losses_ = np.array(losses)
183182
with torch.no_grad(): # Ensure no gradients are computed for final calculations
184-
dfin = self.manifold.dist(X, self.mu_) # Re-calculate dist to final centers
183+
dfin = self.pm.dist(X, self.mu_) # Re-calculate dist to final centers
185184
inv = dfin.pow(-2 / (m - 1)) + 1e-8 # Add epsilon
186185
u_final = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
187186
self.u_ = u_final.detach().cpu().numpy()
@@ -208,19 +207,19 @@ def predict(self, X: Float[torch.Tensor, "n_points n_features"]) -> Int[torch.Te
208207
X = torch.tensor(X, dtype=torch.get_default_dtype())
209208

210209
# Ensure X is on the same device as the manifold
211-
X = X.to(self.manifold.device)
210+
X = X.to(self.pm.device)
212211

213-
if X.shape[1] != self.manifold.ambient_dim:
212+
if X.shape[1] != self.pm.ambient_dim:
214213
raise ValueError(
215214
f"Input data X's dimension ({X.shape[1]}) in predict() does not match "
216-
f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
215+
f"the manifold's ambient dimension ({self.pm.ambient_dim})."
217216
)
218217

219218
if not hasattr(self, "mu_") or self.mu_ is None:
220219
raise RuntimeError("The RFK model has not been fitted yet. Call 'fit' before 'predict'.")
221220

222221
with torch.no_grad():
223-
dmat = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
222+
dmat = self.pm.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
224223
inv = dmat.pow(-2 / (self.m - 1)) + 1e-8 # Add epsilon
225224
u = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
226225
labels = torch.argmax(u, dim=1).cpu().numpy()

manify/curvature_estimation/greedy_method.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def greedy_signature_selection(
2020
candidate_components: Iterable[tuple[float, int]] = ((-1.0, 2), (0.0, 2), (1.0, 2)),
2121
max_components: int = 3,
2222
pipeline: Callable[..., float] = distortion_pipeline,
23+
verbose: bool = False,
2324
**kwargs: dict[str, Any],
2425
) -> tuple[ProductManifold, list[float]]:
2526
r"""Greedily estimates an optimal product manifold signature.
@@ -32,6 +33,7 @@ def greedy_signature_selection(
3233
candidate_components: Candidate (curvature, dimension) pairs to consider.
3334
max_components: Maximum number of components to include.
3435
pipeline: Function that takes a ProductManifold, plus additional arguments, and returns a loss value.
36+
verbose: If True, prints progress information.
3537
**kwargs: Additional keyword arguments to pass to the pipeline function.
3638
3739
Returns:
@@ -44,25 +46,35 @@ def greedy_signature_selection(
4446
candidate_components_list = list(candidate_components) # For type safe iteration
4547

4648
# Greedy loop
47-
for _ in range(max_components):
49+
for i in range(max_components):
50+
if verbose:
51+
print(f"Iteration {i + 1}/{max_components}")
4852
best_loss, best_idx = current_loss, -1
4953

5054
# Try each candidate
5155
for idx, comp in enumerate(candidate_components_list):
52-
pm = ProductManifold(signature=signature + [comp])
56+
if verbose:
57+
print(f" Trying component {comp} (index {idx})")
58+
pm = ProductManifold(signature=signature.copy() + [comp])
5359
loss = pipeline(pm, **kwargs)
5460
if loss < best_loss:
5561
best_loss, best_idx = loss, idx
5662

5763
# If no improvement, stop
5864
if best_idx < 0:
65+
if verbose:
66+
print("No improvement found, stopping.")
5967
break
6068

6169
# Otherwise accept that component
6270
signature.append(candidate_components_list[best_idx])
6371
current_loss = best_loss
6472
loss_history.append(current_loss)
73+
if verbose:
74+
print(f" Accepted component {candidate_components_list[best_idx]} with loss {current_loss:.4f}")
75+
print(f" Current signature: {signature}")
76+
print()
6577

6678
# Return final manifold
67-
optimal_pm = ProductManifold(signature=signature)
79+
optimal_pm = ProductManifold(signature=signature.copy())
6880
return optimal_pm, loss_history

notebooks/Manify Tutorial.ipynb

Lines changed: 895 additions & 90 deletions
Large diffs are not rendered by default.

tests/test_curvature_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_delta_hyperbolicity():
7575
def test_greedy_method():
7676
# Get a very small subset of the polblogs dataset
7777
_, D, _, y = load_hf("polblogs")
78-
D = D[:128, :128] / D.max()
78+
D = D[:128, :128]
7979
y = y[:128]
8080
D = D / D.max()
8181

Commit comments: 0