1313from scipy import sparse
1414import warnings
1515import tasklogger
16+ from packaging import version
1617
1718import matplotlib .pyplot as plt
1819
3233
3334_logger = tasklogger .get_tasklogger ("graphtools" )
3435
36+ # Check graphtools version for random_landmarking support
def _graphtools_supports_random_landmarking():
    """Check if installed graphtools version supports random_landmarking.

    The ``random_landmarking`` parameter is treated as available when the
    installed graphtools reports a version of at least 2.0.0.

    Returns
    -------
    supported : bool
        True if ``graphtools.__version__`` parses to a version >= 2.0.0.
        False if that version is older, or if it cannot be determined
        (no ``__version__`` attribute, a non-string value, or a string that
        is not a valid version), in which case an old release is assumed.
    """
    try:
        installed = version.parse(graphtools.__version__)
    except (AttributeError, TypeError, version.InvalidVersion):
        # Fail closed: an unreadable version must not crash estimator
        # construction or fitting, so fall back to "feature not supported".
        return False
    return installed >= version.parse("2.0.0")
44+
3545
3646class PHATE (BaseEstimator ):
3747 """PHATE operator which performs dimensionality reduction.
@@ -117,6 +127,12 @@ class PHATE(BaseEstimator):
117127 If an integer is given, it fixes the seed
118128 Defaults to the global `numpy` random number generator
119129
130+ random_landmarking : bool, optional, default: False
131+ Whether to use random sampling for landmarking. If True, landmarks
132+ are selected randomly. If False, landmarks are selected deterministically
133+ using spectral clustering.
134+ Defaults to False.
135+
120136 verbose : `int` or `boolean`, optional (default: 1)
121137 If `True` or `> 0`, print status messages
122138
@@ -178,6 +194,7 @@ def __init__(
178194 mds = "metric" ,
179195 n_jobs = 1 ,
180196 random_state = None ,
197+ random_landmarking = False ,
181198 verbose = 1 ,
182199 ** kwargs ,
183200 ):
@@ -201,6 +218,31 @@ def __init__(
201218 self .mds_dist = mds_dist
202219 self .mds_solver = mds_solver
203220 self .random_state = random_state
221+
222+ # Validate random_landmarking parameter
223+ if random_landmarking and n_landmark is None :
224+ warnings .warn (
225+ "random_landmarking=True has no effect when n_landmark=None. "
226+ "Landmarking is disabled when n_landmark=None. "
227+ "To use random landmarking, please set n_landmark to a positive integer "
228+ "(e.g., n_landmark=2000)." ,
229+ UserWarning
230+ )
231+ # Disable random_landmarking since it has no effect
232+ random_landmarking = False
233+ # Check graphtools version if random_landmarking is still requested
234+ elif random_landmarking and not _graphtools_supports_random_landmarking ():
235+ warnings .warn (
236+ "random_landmarking is not available in graphtools version < 2.0.0. "
237+ "Please update graphtools to use this feature: "
238+ "https://pypi.org/project/graphtools/2.0.0/. "
239+ "Falling back to spectral clustering for landmark selection." ,
240+ UserWarning
241+ )
242+ # Disable random_landmarking since it's not supported
243+ random_landmarking = False
244+
245+ self .random_landmarking = random_landmarking
204246 self .kwargs = kwargs
205247
206248 self .graph = None
@@ -485,6 +527,12 @@ def set_params(self, **params):
485527 If an integer is given, it fixes the seed
486528 Defaults to the global `numpy` random number generator
487529
530+ random_landmarking : bool, optional, default: False
531+ Whether to use random sampling for landmarking. If True, landmarks
532+ are selected randomly. If False, landmarks are selected deterministically
533+ using spectral clustering.
534+ Defaults to False.
535+
488536 verbose : `int` or `boolean`, optional (default: 1)
489537 If `True` or `> 0`, print status messages
490538
@@ -615,6 +663,10 @@ def set_params(self, **params):
615663 self .random_state = params ["random_state" ]
616664 self ._set_graph_params (random_state = params ["random_state" ])
617665 del params ["random_state" ]
666+ if "random_landmarking" in params :
667+ self .random_landmarking = params ["random_landmarking" ]
668+ self ._set_graph_params (random_landmarking = params ["random_landmarking" ])
669+ del params ["random_landmarking" ]
618670 if "verbose" in params :
619671 self .verbose = params ["verbose" ]
620672 _logger .set_level (self .verbose )
@@ -751,7 +803,7 @@ def _parse_input(self, X):
751803 n_pca = self .n_pca
752804 return X , n_pca , precomputed , update_graph
753805
754- def _update_graph (self , X , precomputed , n_pca , n_landmark ):
806+ def _update_graph (self , X , precomputed , n_pca , n_landmark , random_landmarking ):
755807 if self .X is not None and not utils .matrix_is_equivalent (X , self .X ):
756808 """
757809 If the same data is used, we can reuse existing kernel and
@@ -760,18 +812,25 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark):
760812 self ._reset_graph ()
761813 else :
762814 try :
763- self .graph .set_params (
764- decay = self .decay ,
765- knn = self .knn ,
766- knn_max = self .knn_max ,
767- distance = self .knn_dist ,
768- precomputed = precomputed ,
769- n_jobs = self .n_jobs ,
770- verbose = self .verbose ,
771- n_pca = n_pca ,
772- n_landmark = n_landmark ,
773- random_state = self .random_state ,
774- )
815+ # Prepare graph params
816+ graph_params = {
817+ 'decay' : self .decay ,
818+ 'knn' : self .knn ,
819+ 'knn_max' : self .knn_max ,
820+ 'distance' : self .knn_dist ,
821+ 'precomputed' : precomputed ,
822+ 'n_jobs' : self .n_jobs ,
823+ 'verbose' : self .verbose ,
824+ 'n_pca' : n_pca ,
825+ 'n_landmark' : n_landmark ,
826+ 'random_state' : self .random_state ,
827+ }
828+
829+ # Only add random_landmarking if graphtools supports it
830+ if _graphtools_supports_random_landmarking ():
831+ graph_params ['random_landmarking' ] = random_landmarking
832+
833+ self .graph .set_params (** graph_params )
775834 _logger .info ("Using precomputed graph and diffusion operator..." )
776835 except ValueError as e :
777836 # something changed that should have invalidated the graph
@@ -816,27 +875,35 @@ def fit(self, X):
816875 n_landmark = self .n_landmark
817876
818877 if self .graph is not None and update_graph :
819- self ._update_graph (X , precomputed , n_pca , n_landmark )
878+ self ._update_graph (X , precomputed , n_pca , n_landmark , self . random_landmarking )
820879
821880 self .X = X
822881
823882 if self .graph is None :
824883 with _logger .log_task ("graph and diffusion operator" ):
825- self .graph = graphtools .Graph (
826- X ,
827- n_pca = n_pca ,
828- n_landmark = n_landmark ,
829- distance = self .knn_dist ,
830- precomputed = precomputed ,
831- knn = self .knn ,
832- knn_max = self .knn_max ,
833- decay = self .decay ,
834- thresh = 1e-4 ,
835- n_jobs = self .n_jobs ,
836- verbose = self .verbose ,
837- random_state = self .random_state ,
838- ** (self .kwargs ),
839- )
884+ # Prepare graph params
885+ graph_params = {
886+ 'n_pca' : n_pca ,
887+ 'n_landmark' : n_landmark ,
888+ 'distance' : self .knn_dist ,
889+ 'precomputed' : precomputed ,
890+ 'knn' : self .knn ,
891+ 'knn_max' : self .knn_max ,
892+ 'decay' : self .decay ,
893+ 'thresh' : 1e-4 ,
894+ 'n_jobs' : self .n_jobs ,
895+ 'verbose' : self .verbose ,
896+ 'random_state' : self .random_state ,
897+ }
898+
899+ # Only add random_landmarking if graphtools supports it
900+ if _graphtools_supports_random_landmarking ():
901+ graph_params ['random_landmarking' ] = self .random_landmarking
902+
903+ # Merge with any additional kwargs
904+ graph_params .update (self .kwargs )
905+
906+ self .graph = graphtools .Graph (X , ** graph_params )
840907
841908 # landmark op doesn't build unless forced
842909 self .diff_op
0 commit comments