Skip to content

Commit 9870369

Browse files
authored
Merge pull request #5 from Yuan-Jinghui/main
Updated the optimizer module and the clustering module.
2 parents 70b668b + b9ef3d6 commit 9870369

File tree

6 files changed

+1225
-0
lines changed

6 files changed

+1225
-0
lines changed

manify/clustering/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
from manify.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans
3+
4+
__all__ = [
5+
"RiemannianFuzzyKMeans"
6+
]
7+

manify/clustering/fuzzy_kmeans.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
'''
2+
The Riemannian Fuzzy K-Means algorithm is a clustering algorithm that operates on Riemannian manifolds.
3+
Compared to a straightforward extension of K-Means or Fuzzy K-Means to Riemannian manifolds,
4+
it offers significant acceleration while achieving lower loss. For more details,
5+
please refer to the paper: https://openreview.net/forum?id=9VmOgMN4Ie
6+
7+
If you find this work useful, please cite the paper as follows:
8+
9+
10+
@article{Yuan2025,
11+
title={Riemannian Fuzzy K-Means},
12+
author={Anonymous},
13+
journal={OpenReview},
14+
year={2025},
15+
url={https://openreview.net/forum?id=9VmOgMN4Ie}
16+
}
17+
18+
If you have questions about the code, feel free to contact: yuanjinghuiiii@gmail.com.
19+
'''
20+
21+
import torch
22+
from geoopt import ManifoldParameter
23+
from geoopt.optim import RiemannianAdam
24+
import numpy as np
25+
from sklearn.base import BaseEstimator, ClusterMixin
26+
from ..optimizers.radan import RiemannianAdan
27+
from ..manifolds import Manifold, ProductManifold
28+
29+
30+
31+
class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
    """Riemannian Fuzzy K-Means clustering on (product) manifolds.

    Optimizes the reformulated fuzzy k-means objective
    ``sum_n (sum_k d(x_n, mu_k)^(-2/(m-1)))^(1-m)`` with a Riemannian
    optimizer over the cluster centers.

    Parameters
    ----------
    n_clusters : int
        The number of clusters to form.
    manifold : Manifold or ProductManifold
        An initialized manifold wrapper object on which clustering is
        performed.  Its ``.dist`` must broadcast (N, D) x (K, D) -> (N, K),
        and ``.manifold`` must expose the underlying geoopt manifold.
    m : float, default=2.0
        Fuzzifier parameter; controls the softness of the partition
        (assumed > 1 — TODO confirm upstream validation).
    lr : float, default=0.1
        Learning rate for the optimizer.
    max_iter : int, default=100
        Maximum number of optimization iterations.
    tol : float, default=1e-4
        Convergence tolerance: iteration stops when the change in loss
        between consecutive steps is below ``tol``.
    optimizer : {'adan', 'adam'}, default='adan'
        The Riemannian optimizer used to update the cluster centers.
    random_state : int or None, default=None
        Seed for reproducible center initialization.
    verbose : bool, default=False
        If True, print the loss at each iteration.

    Attributes
    ----------
    mu_ : geoopt.ManifoldParameter
        Optimized cluster centers, shape (n_clusters, ambient_dim).
    u_ : numpy.ndarray
        Final fuzzy membership matrix, shape (n_samples, n_clusters).
    labels_ : numpy.ndarray
        Hard cluster assignment per sample (argmax of ``u_``).
    cluster_centers_ : numpy.ndarray
        Cluster centers as a NumPy array.
    losses_ : numpy.ndarray
        Loss value recorded at each iteration of ``fit``.
    """

    # Floor applied to pairwise distances before raising them to a negative
    # power.  NOTE: the original code added 1e-8 *after* the negative power,
    # which cannot prevent inf/NaN when a sample coincides with a center;
    # clamping beforehand does.
    _DIST_EPS = 1e-8

    def __init__(self, n_clusters, manifold, m=2.0, lr=0.1, max_iter=100,
                 tol=1e-4, optimizer='adan',
                 random_state=None, verbose=False):
        self.n_clusters = n_clusters
        self.manifold = manifold
        self.m = m
        self.lr = lr
        self.max_iter = max_iter
        self.tol = tol
        if optimizer not in ('adan', 'adam'):
            raise ValueError("optimizer must be 'adan' or 'adam'")
        self.optimizer = optimizer
        self.random_state = random_state
        self.verbose = verbose

    def _as_manifold_tensor(self, X, caller):
        """Convert X to a tensor on the manifold's device and validate its width.

        ``caller`` ('fit' or 'predict') is interpolated into the error
        message so each public method keeps its original wording.
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).type(torch.get_default_dtype())
        elif not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=torch.get_default_dtype())
        # Keep data and centers on the same device as the manifold.
        X = X.to(self.manifold.device)
        if X.shape[1] != self.manifold.ambient_dim:
            raise ValueError(
                f"Input data X's dimension ({X.shape[1]}) in {caller}() does not match "
                f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
            )
        return X

    def _memberships(self, d):
        """Return the (N, K) fuzzy membership matrix for distance matrix ``d``.

        Distances are clamped to ``_DIST_EPS`` before the negative power so a
        zero distance cannot produce inf and poison the row normalization.
        """
        inv = d.clamp_min(self._DIST_EPS).pow(-2 / (self.m - 1))
        return inv / inv.sum(dim=1, keepdim=True)

    def _init_centers(self, X):
        """Initialize ``mu_`` and the optimizer ``opt_`` from data ``X``.

        Centers are sampled around the manifold origin ``mu0`` when the
        manifold type supports sampling; otherwise distinct data points are
        chosen (classic k-means initialization).
        """
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            np.random.seed(self.random_state)

        # Redundant with the check in fit(), but keeps this method safe if
        # ever called directly.
        if X.shape[1] != self.manifold.ambient_dim:
            raise ValueError(
                f"Input data X's dimension ({X.shape[1]}) does not match "
                f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
            )

        # Sample n_clusters points, each around the origin mu0.  Passing
        # sigma=None / sigma_factorized=None lets .sample() fall back to its
        # default identity covariances.
        means_for_sampling_centers = self.manifold.mu0.repeat(self.n_clusters, 1)

        if isinstance(self.manifold, ProductManifold):
            centers, _ = self.manifold.sample(
                z_mean=means_for_sampling_centers,
                sigma_factorized=None,
            )
        elif isinstance(self.manifold, Manifold):
            centers, _ = self.manifold.sample(
                z_mean=means_for_sampling_centers,
                sigma=None,
            )
        else:
            # Fallback: pick n_clusters distinct points from X (common
            # k-means initialization) when sampling is not supported.
            X_device = X.to(self.manifold.device)
            indices = np.random.choice(X_device.shape[0], self.n_clusters, replace=False)
            # Detach in case X carries gradient history.
            centers = X_device[indices].detach()

        # IMPORTANT: self.manifold is our wrapper; geoopt's ManifoldParameter
        # needs the underlying geoopt manifold at self.manifold.manifold.
        self.mu_ = ManifoldParameter(centers.clone().detach(), manifold=self.manifold.manifold)
        self.mu_.requires_grad_(True)

        if self.optimizer == 'adan':
            self.opt_ = RiemannianAdan([self.mu_], lr=self.lr, betas=[0.7, 0.999, 0.999])
        else:
            self.opt_ = RiemannianAdam([self.mu_], lr=self.lr, betas=[0.99, 0.999])

    def fit(self, X, y=None):
        """Fit the model to ``X`` (array-like of shape (N, ambient_dim)).

        Returns ``self``; populates ``mu_``, ``u_``, ``labels_``,
        ``cluster_centers_`` and ``losses_``.
        """
        X = self._as_manifold_tensor(X, "fit")
        self._init_centers(X)

        m, tol = self.m, self.tol
        losses = []
        for i in range(self.max_iter):
            self.opt_.zero_grad()
            # manifold.dist broadcasts: X is (N, D), mu_ is (K, D) -> d is (N, K).
            d = self.manifold.dist(X, self.mu_)
            # Clamp BEFORE the negative power so d == 0 cannot yield inf.
            S = torch.sum(d.clamp_min(self._DIST_EPS).pow(-2 / (m - 1)), dim=1)
            loss = torch.sum(S.pow(1 - m))
            loss.backward()
            losses.append(loss.item())
            self.opt_.step()
            if self.verbose:
                print(f"RFK iter {i + 1}, loss={loss.item():.4f}")
            if i > 0 and abs(losses[-1] - losses[-2]) < tol:
                break

        # Save the results.
        self.losses_ = np.array(losses)
        with torch.no_grad():  # no gradients needed for the final memberships
            u_final = self._memberships(self.manifold.dist(X, self.mu_))
            self.u_ = u_final.detach().cpu().numpy()
        self.labels_ = np.argmax(self.u_, axis=1)
        self.cluster_centers_ = self.mu_.data.clone().detach().cpu().numpy()
        return self

    def predict(self, X):
        """Return hard cluster labels for ``X`` using the fitted centers.

        Raises RuntimeError if called before ``fit``.
        """
        X = self._as_manifold_tensor(X, "predict")

        if not hasattr(self, 'mu_') or self.mu_ is None:
            raise RuntimeError("The RFK model has not been fitted yet. Call 'fit' before 'predict'.")

        with torch.no_grad():
            u = self._memberships(self.manifold.dist(X, self.mu_))
            labels = torch.argmax(u, dim=1).cpu().numpy()
        return labels

manify/optimizers/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
from manify.optimizers.radan import RiemannianAdan
3+
4+
__all__ = [
5+
"RiemannianAdan"
6+
]
7+

0 commit comments

Comments
 (0)