@@ -43,7 +43,7 @@ class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
4343
4444 Attributes:
4545 n_clusters: The number of clusters to form.
46- manifold : An initialized manifold object (from manifolds.py) on which clustering will be performed.
46+ pm : An initialized manifold object (from manifolds.py) on which clustering will be performed.
4747 m: Fuzzifier parameter. Controls the softness of the partition.
4848 lr: Learning rate for the optimizer.
4949 max_iter: Maximum number of iterations for the optimization.
@@ -71,7 +71,7 @@ class RiemannianFuzzyKMeans(BaseEstimator, ClusterMixin):
7171 def __init__ (
7272 self ,
7373 n_clusters : int ,
74- manifold : Manifold | ProductManifold ,
74+ pm : Manifold | ProductManifold ,
7575 m : float = 2.0 ,
7676 lr : float = 0.1 ,
7777 max_iter : int = 100 ,
@@ -81,7 +81,7 @@ def __init__(
8181 verbose : bool = False ,
8282 ):
8383 self .n_clusters = n_clusters
84- self .manifold = manifold
84+ self .pm = pm
8585 self .m = m
8686 self .lr = lr
8787 self .max_iter = max_iter
@@ -97,11 +97,11 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
9797 torch .manual_seed (self .random_state )
9898 np .random .seed (self .random_state )
9999
100- # Input data X's second dimension should match the manifold 's ambient dimension
101- if X .shape [1 ] != self .manifold .ambient_dim :
100+ # Input data X's second dimension should match the ambient dimension of self.pm
101+ if X .shape [1 ] != self .pm .ambient_dim :
102102 raise ValueError (
103103 f"Input data X's dimension ({ X .shape [1 ]} ) does not match "
104- f"the manifold's ambient dimension ({ self .manifold .ambient_dim } )."
104+ f"the manifold's ambient dimension ({ self .pm .ambient_dim } )."
105105 )
106106
107107 # Generate initial centers using the manifold's sample method
@@ -112,16 +112,15 @@ def _init_centers(self, X: Float[torch.Tensor, "n_points n_features"]) -> None:
112112
113113 # For sampling initial centers, we want n_clusters distinct points.
114114 # The .sample() method typically takes a z_mean of shape (num_points_to_sample, ambient_dim).
115- # If we provide self.manifold .mu0 repeated n_clusters times,
115+ # Here we pass self.n_clusters to .sample() instead, asking for n_clusters points
116116 # (presumably each sampled around mu0 -- confirm against manifolds.py).
117- means_for_sampling_centers = self .manifold .mu0 .repeat (self .n_clusters , 1 )
118- centers = self .manifold .sample (z_mean = means_for_sampling_centers )
117+ centers = self .pm .sample (self .n_clusters )
119118
120119 # IMPORTANT: Use self.pm.manifold for ManifoldParameter,
121120 # as self.pm is our wrapper and self.pm.manifold is the geoopt object.
122121 self .mu_ = ManifoldParameter (
123122 centers .clone ().detach (), # type: ignore
124- manifold = self .manifold .manifold ,
123+ manifold = self .pm .manifold ,
125124 ) # Ensure centers are detached
126125 self .mu_ .requires_grad_ (True )
127126
@@ -150,22 +149,22 @@ def fit(self, X: Float[torch.Tensor, "n_points n_features"], y: None = None) ->
150149 X = torch .tensor (X , dtype = torch .get_default_dtype ())
151150
152151 # Ensure X is on the same device as the manifold
153- X = X .to (self .manifold .device )
152+ X = X .to (self .pm .device )
154153
155- if X .shape [1 ] != self .manifold .ambient_dim :
154+ if X .shape [1 ] != self .pm .ambient_dim :
156155 raise ValueError (
157156 f"Input data X's dimension ({ X .shape [1 ]} ) in fit() does not match "
158- f"the manifold's ambient dimension ({ self .manifold .ambient_dim } )."
157+ f"the manifold's ambient dimension ({ self .pm .ambient_dim } )."
159158 )
160159
161160 self ._init_centers (X )
162161 m , tol = self .m , self .tol
163162 losses = []
164163 for i in range (self .max_iter ):
165164 self .opt_ .zero_grad ()
166- # self.manifold .dist is implemented in manifolds.py and handles broadcasting
167- d = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> d is (N,K)
168- # Original RFK: d = self.manifold .dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
165+ # self.pm .dist is implemented in manifolds.py and handles broadcasting
166+ d = self .pm .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> d is (N,K)
167+ # Original RFK: d = self.pm .dist(X.unsqueeze(1), self.mu_.unsqueeze(0))
169168 # The .dist in manifolds.py uses X[:, None] and Y[None, :], so direct call should work if mu_ is (K,D)
170169
171170 S = torch .sum (d .pow (- 2 / (m - 1 )) + 1e-8 , dim = 1 ) # Add epsilon for stability
@@ -181,7 +180,7 @@ def fit(self, X: Float[torch.Tensor, "n_points n_features"], y: None = None) ->
181180 # save the result
182181 self .losses_ = np .array (losses )
183182 with torch .no_grad (): # Ensure no gradients are computed for final calculations
184- dfin = self .manifold .dist (X , self .mu_ ) # Re-calculate dist to final centers
183+ dfin = self .pm .dist (X , self .mu_ ) # Re-calculate dist to final centers
185184 inv = dfin .pow (- 2 / (m - 1 )) + 1e-8 # Add epsilon
186185 u_final = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
187186 self .u_ = u_final .detach ().cpu ().numpy ()
@@ -208,19 +207,19 @@ def predict(self, X: Float[torch.Tensor, "n_points n_features"]) -> Int[torch.Te
208207 X = torch .tensor (X , dtype = torch .get_default_dtype ())
209208
210209 # Ensure X is on the same device as the manifold
211- X = X .to (self .manifold .device )
210+ X = X .to (self .pm .device )
212211
213- if X .shape [1 ] != self .manifold .ambient_dim :
212+ if X .shape [1 ] != self .pm .ambient_dim :
214213 raise ValueError (
215214 f"Input data X's dimension ({ X .shape [1 ]} ) in predict() does not match "
216- f"the manifold's ambient dimension ({ self .manifold .ambient_dim } )."
215+ f"the manifold's ambient dimension ({ self .pm .ambient_dim } )."
217216 )
218217
219218 if not hasattr (self , "mu_" ) or self .mu_ is None :
220219 raise RuntimeError ("The RFK model has not been fitted yet. Call 'fit' before 'predict'." )
221220
222221 with torch .no_grad ():
223- dmat = self .manifold .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
222+ dmat = self .pm .dist (X , self .mu_ ) # X is (N,D), mu_ is (K,D) -> dmat is (N,K)
224223 inv = dmat .pow (- 2 / (self .m - 1 )) + 1e-8 # Add epsilon
225224 u = inv / (inv .sum (dim = 1 , keepdim = True ) + 1e-8 ) # Add epsilon
226225 labels = torch .argmax (u , dim = 1 ).cpu ().numpy ()
0 commit comments