3333
3434_logger = tasklogger .get_tasklogger ("graphtools" )
3535
36+
3637# Check graphtools version
3738def _graphtools_version_is_at_least_2_0 ():
3839 """Check if installed graphtools version is >= 2.0.0.
@@ -132,9 +133,9 @@ class PHATE(BaseEstimator):
132133 If an integer is given, it fixes the seed
133134 Defaults to the global `numpy` random number generator
134135
135- random_landmarking : bool, optional, default: False
136- Whether to use random sampling for landmarking. If True, landmarks
137- are selected randomly. If False, landmarks are selected deterministically
136+ random_landmarking : bool, optional, default: False
137+ Whether to use random sampling for landmarking. If True, landmarks
138+ are selected randomly. If False, landmarks are selected deterministically
138139 using spectral clustering.
139140 Defaults to False.
140141
@@ -231,7 +232,7 @@ def __init__(
231232 "Landmarking is disabled when n_landmark=None. "
232233 "To use random landmarking, please set n_landmark to a positive integer "
233234 "(e.g., n_landmark=2000)." ,
234- UserWarning
235+ UserWarning ,
235236 )
236237 # Disable random_landmarking since it has no effect
237238 random_landmarking = False
@@ -242,7 +243,7 @@ def __init__(
242243 "Please update graphtools to use this feature: "
243244 "https://pypi.org/project/graphtools/2.0.0/. "
244245 "Falling back to spectral clustering for landmark selection." ,
245- UserWarning
246+ UserWarning ,
246247 )
247248 # Disable random_landmarking since it's not supported
248249 random_landmarking = False
@@ -532,9 +533,9 @@ def set_params(self, **params):
532533 If an integer is given, it fixes the seed
533534 Defaults to the global `numpy` random number generator
534535
535- random_landmarking : bool, optional, default: False
536- Whether to use random sampling for landmarking. If True, landmarks
537- are selected randomly. If False, landmarks are selected deterministically
536+ random_landmarking : bool, optional, default: False
537+ Whether to use random sampling for landmarking. If True, landmarks
538+ are selected randomly. If False, landmarks are selected deterministically
538539 using spectral clustering.
539540 Defaults to False.
540541
@@ -819,21 +820,21 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking):
819820 try :
820821 # Prepare graph params
821822 graph_params = {
822- ' decay' : self .decay ,
823- ' knn' : self .knn ,
824- ' knn_max' : self .knn_max ,
825- ' distance' : self .knn_dist ,
826- ' precomputed' : precomputed ,
827- ' n_jobs' : self .n_jobs ,
828- ' verbose' : self .verbose ,
829- ' n_pca' : n_pca ,
830- ' n_landmark' : n_landmark ,
831- ' random_state' : self .random_state ,
823+ " decay" : self .decay ,
824+ " knn" : self .knn ,
825+ " knn_max" : self .knn_max ,
826+ " distance" : self .knn_dist ,
827+ " precomputed" : precomputed ,
828+ " n_jobs" : self .n_jobs ,
829+ " verbose" : self .verbose ,
830+ " n_pca" : n_pca ,
831+ " n_landmark" : n_landmark ,
832+ " random_state" : self .random_state ,
832833 }
833834
834835 # Only add random_landmarking if graphtools supports it
835836 if _graphtools_version_is_at_least_2_0 ():
836- graph_params [' random_landmarking' ] = random_landmarking
837+ graph_params [" random_landmarking" ] = random_landmarking
837838
838839 self .graph .set_params (** graph_params )
839840 _logger .log_info ("Using precomputed graph and diffusion operator..." )
@@ -880,30 +881,32 @@ def fit(self, X):
880881 n_landmark = self .n_landmark
881882
882883 if self .graph is not None and update_graph :
883- self ._update_graph (X , precomputed , n_pca , n_landmark , self .random_landmarking )
884+ self ._update_graph (
885+ X , precomputed , n_pca , n_landmark , self .random_landmarking
886+ )
884887
885888 self .X = X
886889
887890 if self .graph is None :
888891 with _logger .log_task ("graph and diffusion operator" ):
889892 # Prepare graph params
890893 graph_params = {
891- ' n_pca' : n_pca ,
892- ' n_landmark' : n_landmark ,
893- ' distance' : self .knn_dist ,
894- ' precomputed' : precomputed ,
895- ' knn' : self .knn ,
896- ' knn_max' : self .knn_max ,
897- ' decay' : self .decay ,
898- ' thresh' : 1e-4 ,
899- ' n_jobs' : self .n_jobs ,
900- ' verbose' : self .verbose ,
901- ' random_state' : self .random_state ,
894+ " n_pca" : n_pca ,
895+ " n_landmark" : n_landmark ,
896+ " distance" : self .knn_dist ,
897+ " precomputed" : precomputed ,
898+ " knn" : self .knn ,
899+ " knn_max" : self .knn_max ,
900+ " decay" : self .decay ,
901+ " thresh" : 1e-4 ,
902+ " n_jobs" : self .n_jobs ,
903+ " verbose" : self .verbose ,
904+ " random_state" : self .random_state ,
902905 }
903906
904907 # Only add random_landmarking if graphtools supports it
905908 if _graphtools_version_is_at_least_2_0 ():
906- graph_params [' random_landmarking' ] = self .random_landmarking
909+ graph_params [" random_landmarking" ] = self .random_landmarking
907910
908911 # Merge with any additional kwargs
909912 graph_params .update (self .kwargs )
0 commit comments