'''
Riemannian Fuzzy K-Means.

The Riemannian Fuzzy K-Means algorithm is a clustering algorithm that operates
on Riemannian manifolds. Compared to a straightforward extension of K-Means or
Fuzzy K-Means to Riemannian manifolds, it offers significant acceleration while
achieving lower loss. For more details, please refer to the paper:
https://openreview.net/forum?id=9VmOgMN4Ie

If you find this work useful, please cite the paper as follows:

@article{Yuan2025,
  title={Riemannian Fuzzy K-Means},
  author={Anonymous},
  journal={OpenReview},
  year={2025},
  url={https://openreview.net/forum?id=9VmOgMN4Ie}
}

If you have questions about the code, feel free to contact:
yuanjinghuiiii@gmail.com.
'''
20+
21+ import torch
22+ from geoopt import ManifoldParameter
23+ from geoopt .optim import RiemannianAdam
24+ import numpy as np
25+ from sklearn .base import BaseEstimator , ClusterMixin
26+ from ..optimizers .radan import RiemannianAdan
27+ from ..manifolds import Manifold , ProductManifold
28+
29+
30+
class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
    """
    Riemannian Fuzzy K-Means clustering.

    Minimizes the fuzzy K-means objective over a Riemannian manifold by
    optimizing the cluster centers with a Riemannian first-order optimizer
    (see https://openreview.net/forum?id=9VmOgMN4Ie).

    Parameters
    ----------
    n_clusters : int
        The number of clusters to form.
    manifold : Manifold or ProductManifold
        An initialized manifold object (from manifolds.py) on which
        clustering will be performed.
    m : float, default=2.0
        Fuzzifier parameter. Controls the softness of the partition
        (must be > 1 for the exponents below to be defined).
    lr : float, default=0.1
        Learning rate for the optimizer.
    max_iter : int, default=100
        Maximum number of iterations for the optimization.
    tol : float, default=1e-4
        Tolerance for convergence. If the change in loss between two
        consecutive iterations is less than tol, iteration stops.
    optimizer : {'adan', 'adam'}, default='adan'
        The Riemannian optimizer used to update cluster centers.
    random_state : int or None, default=None
        Seed for random number generation for reproducibility.
    verbose : bool, default=False
        Whether to print loss information during iterations.

    Attributes
    ----------
    mu_ : geoopt.ManifoldParameter
        Cluster centers of shape (n_clusters, ambient_dim), set by fit().
    opt_ : torch.optim.Optimizer
        The Riemannian optimizer driving the centers, set by fit().
    u_ : numpy.ndarray
        Final fuzzy membership matrix of shape (n_samples, n_clusters).
    labels_ : numpy.ndarray
        Hard cluster assignment per sample (argmax of u_).
    losses_ : numpy.ndarray
        Loss value recorded at each iteration of fit().
    cluster_centers_ : numpy.ndarray
        Final cluster centers as a NumPy array.
    """

    def __init__(self, n_clusters, manifold, m=2.0, lr=0.1, max_iter=100,
                 tol=1e-4, optimizer='adan',
                 random_state=None, verbose=False):
        self.n_clusters = n_clusters
        self.manifold = manifold
        self.m = m
        self.lr = lr
        self.max_iter = max_iter
        self.tol = tol
        if optimizer not in ('adan', 'adam'):
            raise ValueError("optimizer must be 'adan' or 'adam'")
        self.optimizer = optimizer
        self.random_state = random_state
        self.verbose = verbose

    def _prepare_input(self, X, caller):
        """Convert X to a tensor on the manifold's device and validate its width.

        Shared by fit() and predict(); `caller` is interpolated into the
        error message so the user sees which call failed.
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).type(torch.get_default_dtype())
        elif not isinstance(X, torch.Tensor):
            X = torch.tensor(X, dtype=torch.get_default_dtype())
        # Keep data and centers on the same device as the manifold.
        X = X.to(self.manifold.device)
        if X.shape[1] != self.manifold.ambient_dim:
            raise ValueError(
                f"Input data X's dimension ({X.shape[1]}) in {caller}() does not match "
                f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
            )
        return X

    def _membership(self, X):
        """Return the fuzzy membership matrix u of shape (N, K).

        u[i, k] is proportional to dist(x_i, mu_k)^(-2/(m-1)); small epsilons
        guard against division by zero when a point coincides with a center.
        """
        d = self.manifold.dist(X, self.mu_)  # (N, D) x (K, D) -> (N, K)
        inv = d.pow(-2 / (self.m - 1)) + 1e-8
        return inv / (inv.sum(dim=1, keepdim=True) + 1e-8)

    def _init_centers(self, X):
        """Initialize cluster centers on the manifold and build the optimizer.

        Centers are sampled around the manifold origin (mu0) when the manifold
        type supports .sample(); otherwise they are drawn uniformly from X
        (classic k-means initialization). Side effects: sets self.mu_
        (a ManifoldParameter) and self.opt_.
        """
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            np.random.seed(self.random_state)

        if X.shape[1] != self.manifold.ambient_dim:
            raise ValueError(
                f"Input data X's dimension ({X.shape[1]}) does not match "
                f"the manifold's ambient dimension ({self.manifold.ambient_dim})."
            )

        # Sample n_clusters points, each around the manifold origin mu0.
        # .sample() takes z_mean of shape (num_points, ambient_dim); passing
        # sigma / sigma_factorized = None makes it fall back to identity
        # covariances.
        means_for_centers = self.manifold.mu0.repeat(self.n_clusters, 1)

        if isinstance(self.manifold, ProductManifold):
            centers, _ = self.manifold.sample(
                z_mean=means_for_centers,
                sigma_factorized=None,
            )
        elif isinstance(self.manifold, Manifold):
            centers, _ = self.manifold.sample(
                z_mean=means_for_centers,
                sigma=None,
            )
        else:
            # Fallback for manifold types without .sample(): pick random
            # data points, detached in case X carries gradients.
            X_device = X.to(self.manifold.device)
            indices = np.random.choice(X_device.shape[0], self.n_clusters,
                                       replace=False)
            centers = X_device[indices].detach()

        # IMPORTANT: ManifoldParameter needs the underlying geoopt object
        # (self.manifold.manifold); self.manifold is only our wrapper.
        self.mu_ = ManifoldParameter(centers.clone().detach(),
                                     manifold=self.manifold.manifold)
        self.mu_.requires_grad_(True)

        if self.optimizer == 'adan':
            self.opt_ = RiemannianAdan([self.mu_], lr=self.lr,
                                       betas=(0.7, 0.999, 0.999))
        else:
            self.opt_ = RiemannianAdam([self.mu_], lr=self.lr,
                                       betas=(0.99, 0.999))

    def fit(self, X, y=None):
        """Fit the fuzzy K-means model to X.

        Parameters
        ----------
        X : array-like of shape (n_samples, ambient_dim)
            Points on (or embedded around) the manifold.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        X = self._prepare_input(X, 'fit')
        self._init_centers(X)
        m, tol = self.m, self.tol
        losses = []
        for i in range(self.max_iter):
            self.opt_.zero_grad()
            # self.manifold.dist broadcasts (N, D) against (K, D) -> (N, K).
            d = self.manifold.dist(X, self.mu_)
            # Reciprocal-distance form of the fuzzy K-means objective; the
            # epsilon keeps the gradient finite when a distance is zero.
            S = torch.sum(d.pow(-2 / (m - 1)) + 1e-8, dim=1)
            loss = torch.sum(S.pow(1 - m))
            loss.backward()
            losses.append(loss.item())
            self.opt_.step()
            if self.verbose:
                print(f"RFK iter {i + 1}, loss={loss.item():.4f}")
            if i > 0 and abs(losses[-1] - losses[-2]) < tol:
                break

        # Save results: memberships, hard labels, and final centers.
        self.losses_ = np.array(losses)
        with torch.no_grad():
            u_final = self._membership(X)
        self.u_ = u_final.detach().cpu().numpy()
        self.labels_ = np.argmax(self.u_, axis=1)
        self.cluster_centers_ = self.mu_.data.clone().detach().cpu().numpy()
        return self

    def predict(self, X):
        """Return the hard cluster label for each row of X.

        Raises
        ------
        RuntimeError
            If called before fit().
        ValueError
            If X's second dimension does not match the manifold's
            ambient dimension.
        """
        # Check fitted state first so an unfitted model fails with the
        # right error regardless of X's shape.
        if not hasattr(self, 'mu_') or self.mu_ is None:
            raise RuntimeError("The RFK model has not been fitted yet. Call 'fit' before 'predict'.")
        X = self._prepare_input(X, 'predict')
        with torch.no_grad():
            u = self._membership(X)
            labels = torch.argmax(u, dim=1).cpu().numpy()
        return labels