Skip to content

Commit 8e4b8aa

Browse files
committed
Partial fix to RFK docstrings
1 parent a599aa7 commit 8e4b8aa

File tree

6 files changed

+271
-246
lines changed

6 files changed

+271
-246
lines changed

manify/clustering/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# -*- coding: utf-8 -*-
2-
from manify.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans
1+
"""Clustering algorithms for Riemannian manifolds. Under construction."""
32

4-
__all__ = [
5-
"RiemannianFuzzyKMeans"
6-
]
3+
from manify.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans
74

5+
__all__ = ["RiemannianFuzzyKMeans"]

manify/clustering/fuzzy_kmeans.py

Lines changed: 104 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,93 @@
1-
'''
2-
The Riemannian Fuzzy K-Means algorithm is a clustering algorithm that operates on Riemannian manifolds.
3-
Compared to a straightforward extension of K-Means or Fuzzy K-Means to Riemannian manifolds,
4-
it offers significant acceleration while achieving lower loss. For more details,
1+
"""The Riemannian Fuzzy K-Means algorithm is a clustering algorithm that operates on Riemannian manifolds.
2+
Compared to a straightforward extension of K-Means or Fuzzy K-Means to Riemannian manifolds,
3+
it offers significant acceleration while achieving lower loss. For more details,
54
please refer to the paper: https://openreview.net/forum?id=9VmOgMN4Ie
65
76
If you find this work useful, please cite the paper as follows:
87
9-
8+
```bibtex
109
@article{Yuan2025,
1110
title={Riemannian Fuzzy K-Means},
1211
author={Anonymous},
1312
journal={OpenReview},
1413
year={2025},
1514
url={https://openreview.net/forum?id=9VmOgMN4Ie}
1615
}
16+
```
1717
1818
If you have questions about the code, feel free to contact: yuanjinghuiiii@gmail.com.
19-
'''
19+
"""
20+
21+
from __future__ import annotations
22+
23+
from typing import Literal, Optional, Union
2024

25+
import numpy as np
2126
import torch
2227
from geoopt import ManifoldParameter
2328
from geoopt.optim import RiemannianAdam
24-
import numpy as np
29+
from jaxtyping import Float, Int
2530
from sklearn.base import BaseEstimator, ClusterMixin
26-
from ..optimizers.radan import RiemannianAdan
27-
from ..manifolds import Manifold, ProductManifold
2831

32+
from ..manifolds import Manifold, ProductManifold
33+
from ..optimizers.radan import RiemannianAdan
2934

3035

3136
class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
37+
"""Riemannian Fuzzy K-Means.
38+
39+
Attributes:
40+
n_clusters: The number of clusters to form.
41+
manifold: An initialized manifold object (from manifolds.py) on which clustering will be performed.
42+
m: Fuzzifier parameter. Controls the softness of the partition.
43+
lr: Learning rate for the optimizer.
44+
max_iter: Maximum number of iterations for the optimization.
45+
tol: Tolerance for convergence. If the change in loss is less than tol, iteration stops.
46+
optimizer: The optimizer to use for updating cluster centers.
47+
random_state: Seed for random number generation for reproducibility.
48+
verbose: Whether to print loss information during iterations.
49+
losses_: List of loss values during training.
50+
u_: Final fuzzy partition matrix.
51+
labels_: Cluster labels for each sample.
52+
cluster_centers_: Final cluster centers.
53+
54+
Args:
55+
n_clusters: The number of clusters to form.
56+
manifold: An initialized manifold object (from manifolds.py) on which clustering will be performed.
57+
m: Fuzzifier parameter. Controls the softness of the partition.
58+
lr: Learning rate for the optimizer.
59+
max_iter: Maximum number of iterations for the optimization.
60+
tol: Tolerance for convergence. If the change in loss is less than tol, iteration stops.
61+
optimizer: The optimizer to use for updating cluster centers.
62+
random_state: Seed for random number generation for reproducibility.
63+
verbose: Whether to print loss information during iterations.
3264
"""
33-
Riemannian Fuzzy K-Means.
34-
35-
param:
36-
----------
37-
n_clusters : int
38-
The number of clusters to form.
39-
manifold : Manifold or ProductManifold
40-
An initialized manifold object (from manifolds.py) on which clustering will be performed.
41-
m : float, default=2.0
42-
Fuzzifier parameter. Controls the softness of the partition.
43-
lr : float, default=0.1
44-
Learning rate for the optimizer.
45-
max_iter : int, default=100
46-
Maximum number of iterations for the optimization.
47-
tol : float, default=1e-4
48-
Tolerance for convergence. If the change in loss is less than tol, iteration stops.
49-
optimizer : {'adan','adam'}, default='adan'
50-
The optimizer to use for updating cluster centers.
51-
random_state : int or None, default=None
52-
Seed for random number generation for reproducibility.
53-
verbose : bool, default=False
54-
Whether to print loss information during iterations.
55-
"""
56-
def __init__(self, n_clusters, manifold, m=2.0, lr=0.1, max_iter=100,
57-
tol=1e-4, optimizer='adan',
58-
random_state=None, verbose=False):
65+
66+
def __init__(
67+
self,
68+
n_clusters: int,
69+
manifold: Union[Manifold, ProductManifold],
70+
m: float = 2.0,
71+
lr: float = 0.1,
72+
max_iter: int = 100,
73+
tol: float = 1e-4,
74+
optimizer: Literal["adan", "adam"] = "adan",
75+
random_state: Optional[int] = None,
76+
verbose: bool = False,
77+
):
5978
self.n_clusters = n_clusters
60-
self.manifold = manifold
79+
self.manifold = manifold
6180
self.m = m
6281
self.lr = lr
6382
self.max_iter = max_iter
6483
self.tol = tol
65-
if optimizer not in ('adan','adam'):
84+
if optimizer not in ("adan", "adam"):
6685
raise ValueError("optimizer must be 'adan' or 'adam'")
6786
self.optimizer = optimizer
6887
self.random_state = random_state
6988
self.verbose = verbose
7089

71-
def _init_centers(self, X):
90+
def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]):
7291
if self.random_state is not None:
7392
torch.manual_seed(self.random_state)
7493
np.random.seed(self.random_state)
@@ -91,43 +110,51 @@ def _init_centers(self, X):
91110
# If we provide self.manifold.mu0 repeated n_clusters times,
92111
# it samples n_clusters points, each around mu0.
93112
means_for_sampling_centers = self.manifold.mu0.repeat(self.n_clusters, 1)
94-
113+
95114
if isinstance(self.manifold, ProductManifold):
96115
# sigma_factorized should be a list of [n_clusters, M.dim, M.dim] tensors
97116
# Setting to None will use default identity covariances in .sample()
98-
centers, _ = self.manifold.sample(
99-
z_mean=means_for_sampling_centers,
100-
sigma_factorized=None
101-
)
117+
centers, _ = self.manifold.sample(z_mean=means_for_sampling_centers, sigma_factorized=None)
102118
elif isinstance(self.manifold, Manifold):
103119
# sigma should be a [n_clusters, self.manifold.dim, self.manifold.dim] tensor
104120
# Setting to None will use default identity covariance in .sample()
105-
centers, _ = self.manifold.sample(
106-
z_mean=means_for_sampling_centers,
107-
sigma=None
108-
)
121+
centers, _ = self.manifold.sample(z_mean=means_for_sampling_centers, sigma=None)
109122
else:
110123
# Fallback: Randomly select points from X if the manifold type isn't directly supported for sampling
111124
# This is a common k-means initialization strategy.
112125
# Ensure X is on the correct device first.
113-
X_device = X.to(self.manifold.device) # Ensure X is on the manifold's device
126+
X_device = X.to(self.manifold.device) # Ensure X is on the manifold's device
114127
indices = np.random.choice(X_device.shape[0], self.n_clusters, replace=False)
115128
centers = X_device[indices]
116129
# Ensure centers are detached if they came from X which might require grad
117130
centers = centers.detach()
118131

119-
120132
# IMPORTANT: Use self.manifold.manifold for ManifoldParameter,
121133
# as self.manifold is our wrapper and self.manifold.manifold is the geoopt object.
122-
self.mu_ = ManifoldParameter(centers.clone().detach(), manifold=self.manifold.manifold) # Ensure centers are detached
134+
self.mu_ = ManifoldParameter(
135+
centers.clone().detach(), manifold=self.manifold.manifold
136+
) # Ensure centers are detached
123137
self.mu_.requires_grad_(True)
124138

125-
if self.optimizer == 'adan':
139+
if self.optimizer == "adan":
126140
self.opt_ = RiemannianAdan([self.mu_], lr=self.lr, betas=[0.7, 0.999, 0.999])
127141
else:
128142
self.opt_ = RiemannianAdam([self.mu_], lr=self.lr, betas=[0.99, 0.999])
129143

130-
def fit(self, X, y=None):
144+
def fit(self, X: Float[torch.Tensor, "n_points n_features"], y: None = None) -> "RiemannianFuzzyKMeans":
145+
"""Fit the Riemannian Fuzzy K-Means model to the data X.
146+
147+
Args:
148+
X: Input data. Features should match the manifold's geometry.
149+
y: Ignored, present for compatibility with scikit-learn's API.
150+
151+
Returns:
152+
self: Fitted `RiemannianFuzzyKMeans` instance.
153+
154+
Raises:
155+
ValueError: If the input data's dimension does not match the manifold's ambient dimension.
156+
RuntimeError: If the optimizer is not set correctly or if the model has not been initialized properly.
157+
"""
131158
if isinstance(X, np.ndarray):
132159
X = torch.from_numpy(X).type(torch.get_default_dtype())
133160
elif not isinstance(X, torch.Tensor):
@@ -137,7 +164,7 @@ def fit(self, X, y=None):
137164
X = X.to(self.manifold.device)
138165

139166
if X.shape[1] != self.manifold.ambient_dim:
140-
raise ValueError(
167+
raise ValueError(
141168
f"Input data X's dimension ({X.shape[1]}) in fit() does not match "
142169
f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
143170
)
@@ -148,7 +175,7 @@ def fit(self, X, y=None):
148175
for i in range(self.max_iter):
149176
self.opt_.zero_grad()
150177
# self.manifold.dist is implemented in manifolds.py and handles broadcasting
151-
d = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> d is (N,K)
178+
d = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> d is (N,K)
152179
# Original RFK: d = self.manifold.dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
153180
# The .dist in manifolds.py uses X[:, None] and Y[None, :], so direct call should work if mu_ is (K,D)
154181

@@ -161,18 +188,31 @@ def fit(self, X, y=None):
161188
print(f"RFK iter {i + 1}, loss={loss.item():.4f}")
162189
if i > 0 and abs(losses[-1] - losses[-2]) < tol:
163190
break
191+
164192
# save the result
165193
self.losses_ = np.array(losses)
166-
with torch.no_grad(): # Ensure no gradients are computed for final calculations
167-
dfin = self.manifold.dist(X, self.mu_) # Re-calculate dist to final centers
168-
inv = dfin.pow(-2 / (m - 1)) + 1e-8 # Add epsilon
169-
u_final = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
194+
with torch.no_grad(): # Ensure no gradients are computed for final calculations
195+
dfin = self.manifold.dist(X, self.mu_) # Re-calculate dist to final centers
196+
inv = dfin.pow(-2 / (m - 1)) + 1e-8 # Add epsilon
197+
u_final = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
170198
self.u_ = u_final.detach().cpu().numpy()
171199
self.labels_ = np.argmax(self.u_, axis=1)
172200
self.cluster_centers_ = self.mu_.data.clone().detach().cpu().numpy()
173201
return self
174202

175-
def predict(self, X):
203+
def predict(self, X: Float[torch.Tensor, "n_points n_features"]) -> Int[torch.Tensor, "n_points"]:
204+
"""Predict the closest cluster each sample in X belongs to.
205+
206+
Args:
207+
X: Input data. Features should match the manifold's geometry.
208+
209+
Returns:
210+
labels: Cluster labels for each sample in X.
211+
212+
Raises:
213+
ValueError: If the input data's dimension does not match the manifold's ambient dimension.
214+
RuntimeError: If the model has not been fitted yet.
215+
"""
176216
if isinstance(X, np.ndarray):
177217
X = torch.from_numpy(X).type(torch.get_default_dtype())
178218
elif not isinstance(X, torch.Tensor):
@@ -187,12 +227,12 @@ def predict(self, X):
187227
f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
188228
)
189229

190-
if not hasattr(self, 'mu_') or self.mu_ is None:
230+
if not hasattr(self, "mu_") or self.mu_ is None:
191231
raise RuntimeError("The RFK model has not been fitted yet. Call 'fit' before 'predict'.")
192232

193233
with torch.no_grad():
194-
dmat = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
195-
inv = dmat.pow(-2 / (self.m - 1)) + 1e-8 # Add epsilon
196-
u = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
234+
dmat = self.manifold.dist(X, self.mu_) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
235+
inv = dmat.pow(-2 / (self.m - 1)) + 1e-8 # Add epsilon
236+
u = inv / (inv.sum(dim=1, keepdim=True) + 1e-8) # Add epsilon
197237
labels = torch.argmax(u, dim=1).cpu().numpy()
198-
return labels
238+
return labels

manify/embedders/coordinate_learning.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class CoordinateLearning(BaseEmbedder):
5252
5353
Attributes:
5454
pm: Product manifold defining the target embedding space.
55-
lr: Learning rate for the optimizer.
5655
embeddings_: Optimized point coordinates after fitting.
5756
loss_history_: Training loss history.
5857
is_fitted_: Boolean flag indicating if the embedder has been fitted.

manify/optimizers/__init__.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# -*- coding: utf-8 -*-
2-
from manify.optimizers.radan import RiemannianAdan
1+
"""New Riemannian Adan optimizer implementation."""
32

4-
__all__ = [
5-
"RiemannianAdan"
6-
]
3+
from manify.optimizers.radan import RiemannianAdan
74

5+
__all__ = ["RiemannianAdan"]

0 commit comments

Comments
 (0)