"""The Riemannian Fuzzy K-Means algorithm is a clustering algorithm that operates on Riemannian manifolds.

Compared to a straightforward extension of K-Means or Fuzzy K-Means to Riemannian manifolds,
it offers significant acceleration while achieving lower loss. For more details,
please refer to the paper: https://openreview.net/forum?id=9VmOgMN4Ie

If you find this work useful, please cite the paper as follows:

```bibtex
@article{Yuan2025,
  title={Riemannian Fuzzy K-Means},
  author={Anonymous},
  journal={OpenReview},
  year={2025},
  url={https://openreview.net/forum?id=9VmOgMN4Ie}
}
```

If you have questions about the code, feel free to contact: yuanjinghuiiii@gmail.com.
"""
20+
from __future__ import annotations

from typing import Literal, Optional, Union

import numpy as np
import torch
from geoopt import ManifoldParameter
from geoopt.optim import RiemannianAdam
from jaxtyping import Float, Int
from sklearn.base import BaseEstimator, ClusterMixin

from ..manifolds import Manifold, ProductManifold
from ..optimizers.radan import RiemannianAdan
3035
3136class RiemannianFuzzyKMeans (BaseEstimator , ClusterMixin ):
    """Riemannian Fuzzy K-Means.

    Attributes:
        n_clusters: The number of clusters to form.
        manifold: An initialized manifold object (from manifolds.py) on which clustering will be performed.
        m: Fuzzifier parameter. Controls the softness of the partition.
        lr: Learning rate for the optimizer.
        max_iter: Maximum number of iterations for the optimization.
        tol: Tolerance for convergence. If the change in loss is less than tol, iteration stops.
        optimizer: The optimizer to use for updating cluster centers.
        random_state: Seed for random number generation for reproducibility.
        verbose: Whether to print loss information during iterations.
        losses_: List of loss values during training.
        u_: Final fuzzy partition matrix.
        labels_: Cluster labels for each sample.
        cluster_centers_: Final cluster centers.

    Args:
        n_clusters: The number of clusters to form.
        manifold: An initialized manifold object (from manifolds.py) on which clustering will be performed.
        m: Fuzzifier parameter. Controls the softness of the partition.
        lr: Learning rate for the optimizer.
        max_iter: Maximum number of iterations for the optimization.
        tol: Tolerance for convergence. If the change in loss is less than tol, iteration stops.
        optimizer: The optimizer to use for updating cluster centers.
        random_state: Seed for random number generation for reproducibility.
        verbose: Whether to print loss information during iterations.
    """
56- def __init__ (self , n_clusters , manifold , m = 2.0 , lr = 0.1 , max_iter = 100 ,
57- tol = 1e-4 , optimizer = 'adan' ,
58- random_state = None , verbose = False ):
65+
66+ def __init__ (
67+ self ,
68+ n_clusters : int ,
69+ manifold : Union [Manifold , ProductManifold ],
70+ m : float = 2.0 ,
71+ lr : float = 0.1 ,
72+ max_iter : int = 100 ,
73+ tol : float = 1e-4 ,
74+ optimizer : Literal ["adan" , "adam" ] = "adan" ,
75+ random_state : Optional [int ] = None ,
76+ verbose : bool = False ,
77+ ):
5978 self .n_clusters = n_clusters
60- self .manifold = manifold
79+ self .manifold = manifold
6180 self .m = m
6281 self .lr = lr
6382 self .max_iter = max_iter
6483 self .tol = tol
65- if optimizer not in (' adan' , ' adam' ):
84+ if optimizer not in (" adan" , " adam" ):
6685 raise ValueError ("optimizer must be 'adan' or 'adam'" )
6786 self .optimizer = optimizer
6887 self .random_state = random_state
6988 self .verbose = verbose
7089
71- def _init_centers (self , X ):
90+ def _init_centers (self , X : Float [ torch . Tensor , "n_points n_features" ] ):
7291 if self .random_state is not None :
7392 torch .manual_seed (self .random_state )
7493 np .random .seed (self .random_state )
@@ -91,43 +110,51 @@ def _init_centers(self, X):
91110 # If we provide self.manifold.mu0 repeated n_clusters times,
92111 # it samples n_clusters points, each around mu0.
93112 means_for_sampling_centers = self .manifold .mu0 .repeat (self .n_clusters , 1 )
94-
113+
95114 if isinstance (self .manifold , ProductManifold ):
96115 # sigma_factorized should be a list of [n_clusters, M.dim, M.dim] tensors
97116 # Setting to None will use default identity covariances in .sample()
98- centers , _ = self .manifold .sample (
99- z_mean = means_for_sampling_centers ,
100- sigma_factorized = None
101- )
117+ centers , _ = self .manifold .sample (z_mean = means_for_sampling_centers , sigma_factorized = None )
102118 elif isinstance (self .manifold , Manifold ):
103119 # sigma should be a [n_clusters, self.manifold.dim, self.manifold.dim] tensor
104120 # Setting to None will use default identity covariance in .sample()
105- centers , _ = self .manifold .sample (
106- z_mean = means_for_sampling_centers ,
107- sigma = None
108- )
121+ centers , _ = self .manifold .sample (z_mean = means_for_sampling_centers , sigma = None )
109122 else :
110123 # Fallback: Randomly select points from X if the manifold type isn't directly supported for sampling
111124 # This is a common k-means initialization strategy.
112125 # Ensure X is on the correct device first.
113- X_device = X .to (self .manifold .device ) # Ensure X is on the manifold's device
126+ X_device = X .to (self .manifold .device ) # Ensure X is on the manifold's device
114127 indices = np .random .choice (X_device .shape [0 ], self .n_clusters , replace = False )
115128 centers = X_device [indices ]
116129 # Ensure centers are detached if they came from X which might require grad
117130 centers = centers .detach ()
118131
119-
120132 # IMPORTANT: Use self.manifold.manifold for ManifoldParameter,
121133 # as self.manifold is our wrapper and self.manifold.manifold is the geoopt object.
122- self .mu_ = ManifoldParameter (centers .clone ().detach (), manifold = self .manifold .manifold ) # Ensure centers are detached
134+ self .mu_ = ManifoldParameter (
135+ centers .clone ().detach (), manifold = self .manifold .manifold
136+ ) # Ensure centers are detached
123137 self .mu_ .requires_grad_ (True )
124138
125- if self .optimizer == ' adan' :
139+ if self .optimizer == " adan" :
126140 self .opt_ = RiemannianAdan ([self .mu_ ], lr = self .lr , betas = [0.7 , 0.999 , 0.999 ])
127141 else :
128142 self .opt_ = RiemannianAdam ([self .mu_ ], lr = self .lr , betas = [0.99 , 0.999 ])
129143
130- def fit (self , X , y = None ):
144+ def fit (self , X : Float [torch .Tensor , "n_points n_features" ], y : None = None ) -> "RiemannianFuzzyKMeans" :
145+ """Fit the Riemannian Fuzzy K-Means model to the data X.
146+
147+ Args:
148+ X: Input data. Features should match the manifold's geometry.
149+ y: Ignored, present for compatibility with scikit-learn's API.
150+
151+ Returns:
152+ self: Fitted `RiemannianFuzzyKMeans` instance.
153+
154+ Raises:
155+ ValueError: If the input data's dimension does not match the manifold's ambient dimension.
156+ RuntimeError: If the optimizer is not set correctly or if the model has not been initialized properly.
157+ """
131158 if isinstance (X , np .ndarray ):
132159 X = torch .from_numpy (X ).type (torch .get_default_dtype ())
133160 elif not isinstance (X , torch .Tensor ):
@@ -137,7 +164,7 @@ def fit(self, X, y=None):
137164 X = X .to (self .manifold .device )
138165
139166 if X .shape [1 ] != self .manifold .ambient_dim :
140- raise ValueError (
167+ raise ValueError (
141168 f"Input data X's dimension ({ X .shape [1 ]} ) in fit() does not match "
142169 f"the manifold's ambient dimension ({ self .manifold .ambient_dim } )."
143170 )
@@ -148,7 +175,7 @@ def fit(self, X, y=None):
148175 for i in range (self .max_iter ):
149176 self .opt_ .zero_grad ()
150177 # self.manifold.dist is implemented in manifolds.py and handles broadcasting
151- d = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> d is (N,K)
178+ d = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> d is (N,K)
152179 # Original RFK: d = self.manifold.dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
153180 # The .dist in manifolds.py uses X[:, None] and Y[None, :], so direct call should work if mu_ is (K,D)
154181
@@ -161,18 +188,31 @@ def fit(self, X, y=None):
161188 print (f"RFK iter { i + 1 } , loss={ loss .item ():.4f} " )
162189 if i > 0 and abs (losses [- 1 ] - losses [- 2 ]) < tol :
163190 break
191+
164192 # save the result
165193 self .losses_ = np .array (losses )
166- with torch .no_grad (): # Ensure no gradients are computed for final calculations
167- dfin = self .manifold .dist (X , self .mu_ ) # Re-calculate dist to final centers
168- inv = dfin .pow (- 2 / (m - 1 )) + 1e-8 # Add epsilon
169- u_final = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
194+ with torch .no_grad (): # Ensure no gradients are computed for final calculations
195+ dfin = self .manifold .dist (X , self .mu_ ) # Re-calculate dist to final centers
196+ inv = dfin .pow (- 2 / (m - 1 )) + 1e-8 # Add epsilon
197+ u_final = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
170198 self .u_ = u_final .detach ().cpu ().numpy ()
171199 self .labels_ = np .argmax (self .u_ , axis = 1 )
172200 self .cluster_centers_ = self .mu_ .data .clone ().detach ().cpu ().numpy ()
173201 return self
174202
175- def predict (self , X ):
203+ def predict (self , X : Float [torch .Tensor , "n_points n_features" ]) -> Int [torch .Tensor , "n_points" ]:
204+ """Predict the closest cluster each sample in X belongs to.
205+
206+ Args:
207+ X: Input data. Features should match the manifold's geometry.
208+
209+ Returns:
210+ labels: Cluster labels for each sample in X.
211+
212+ Raises:
213+ ValueError: If the input data's dimension does not match the manifold's ambient dimension.
214+ RuntimeError: If the model has not been fitted yet.
215+ """
176216 if isinstance (X , np .ndarray ):
177217 X = torch .from_numpy (X ).type (torch .get_default_dtype ())
178218 elif not isinstance (X , torch .Tensor ):
@@ -187,12 +227,12 @@ def predict(self, X):
187227 f"the manifold's ambient dimension ({ self .manifold .ambient_dim } )."
188228 )
189229
190- if not hasattr (self , ' mu_' ) or self .mu_ is None :
230+ if not hasattr (self , " mu_" ) or self .mu_ is None :
191231 raise RuntimeError ("The RFK model has not been fitted yet. Call 'fit' before 'predict'." )
192232
193233 with torch .no_grad ():
194- dmat = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
195- inv = dmat .pow (- 2 / (self .m - 1 )) + 1e-8 # Add epsilon
196- u = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
234+ dmat = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
235+ inv = dmat .pow (- 2 / (self .m - 1 )) + 1e-8 # Add epsilon
236+ u = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
197237 labels = torch .argmax (u , dim = 1 ).cpu ().numpy ()
198- return labels
238+ return labels
0 commit comments