diff --git a/Python/phate/phate.py b/Python/phate/phate.py index 4fb608a..e94c576 100644 --- a/Python/phate/phate.py +++ b/Python/phate/phate.py @@ -13,6 +13,7 @@ from scipy import sparse import warnings import tasklogger +from packaging import version import matplotlib.pyplot as plt @@ -32,6 +33,15 @@ _logger = tasklogger.get_tasklogger("graphtools") +# Check graphtools version for random_landmarking support +def _graphtools_supports_random_landmarking(): + """Return True when the installed graphtools (>= 2.0.0) accepts random_landmarking.""" + try: + return version.parse(graphtools.__version__) >= version.parse("2.0.0") + except (AttributeError, version.InvalidVersion): + # __version__ is missing or not a parseable version string, assume old version + return False + class PHATE(BaseEstimator): """PHATE operator which performs dimensionality reduction. @@ -117,6 +127,12 @@ class PHATE(BaseEstimator): If an integer is given, it fixes the seed Defaults to the global `numpy` random number generator + random_landmarking : bool, optional, default: False + Whether to use random sampling for landmarking. If True, landmarks + are selected randomly. If False, landmarks are selected deterministically + using spectral clustering. + Defaults to False. + verbose : `int` or `boolean`, optional (default: 1) If `True` or `> 0`, print status messages @@ -178,6 +194,7 @@ def __init__( mds="metric", n_jobs=1, random_state=None, + random_landmarking=False, verbose=1, **kwargs, ): @@ -201,6 +218,31 @@ def __init__( self.mds_dist = mds_dist self.mds_solver = mds_solver self.random_state = random_state + + # Validate random_landmarking parameter + if random_landmarking and n_landmark is None: + warnings.warn( + "random_landmarking=True has no effect when n_landmark=None. " + "Landmarking is disabled when n_landmark=None. 
" + "To use random landmarking, please set n_landmark to a positive integer " + "(e.g., n_landmark=2000).", + UserWarning + ) + # Disable random_landmarking since it has no effect + random_landmarking = False + # Check graphtools version if random_landmarking is still requested + elif random_landmarking and not _graphtools_supports_random_landmarking(): + warnings.warn( + "random_landmarking is not available in graphtools version < 2.0.0. " + "Please update graphtools to use this feature: " + "https://pypi.org/project/graphtools/2.0.0/. " + "Falling back to spectral clustering for landmark selection.", + UserWarning + ) + # Disable random_landmarking since it's not supported + random_landmarking = False + + self.random_landmarking = random_landmarking self.kwargs = kwargs self.graph = None @@ -485,6 +527,12 @@ def set_params(self, **params): If an integer is given, it fixes the seed Defaults to the global `numpy` random number generator + random_landmarking : bool, optional, default: False + Whether to use random sampling for landmarking. If True, landmarks + are selected randomly. If False, landmarks are selected deterministically + using spectral clustering. + Defaults to False. 
+ verbose : `int` or `boolean`, optional (default: 1) If `True` or `> 0`, print status messages @@ -615,6 +663,10 @@ def set_params(self, **params): self.random_state = params["random_state"] self._set_graph_params(random_state=params["random_state"]) del params["random_state"] + if "random_landmarking" in params: + self.random_landmarking = params["random_landmarking"] + self._set_graph_params(random_landmarking=params["random_landmarking"]) + del params["random_landmarking"] if "verbose" in params: self.verbose = params["verbose"] _logger.set_level(self.verbose) @@ -751,7 +803,7 @@ def _parse_input(self, X): n_pca = self.n_pca return X, n_pca, precomputed, update_graph - def _update_graph(self, X, precomputed, n_pca, n_landmark): + def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking): if self.X is not None and not utils.matrix_is_equivalent(X, self.X): """ If the same data is used, we can reuse existing kernel and @@ -760,18 +812,25 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark): self._reset_graph() else: try: - self.graph.set_params( - decay=self.decay, - knn=self.knn, - knn_max=self.knn_max, - distance=self.knn_dist, - precomputed=precomputed, - n_jobs=self.n_jobs, - verbose=self.verbose, - n_pca=n_pca, - n_landmark=n_landmark, - random_state=self.random_state, - ) + # Prepare graph params + graph_params = { + 'decay': self.decay, + 'knn': self.knn, + 'knn_max': self.knn_max, + 'distance': self.knn_dist, + 'precomputed': precomputed, + 'n_jobs': self.n_jobs, + 'verbose': self.verbose, + 'n_pca': n_pca, + 'n_landmark': n_landmark, + 'random_state': self.random_state, + } + + # Only add random_landmarking if graphtools supports it + if _graphtools_supports_random_landmarking(): + graph_params['random_landmarking'] = random_landmarking + + self.graph.set_params(**graph_params) _logger.info("Using precomputed graph and diffusion operator...") except ValueError as e: # something changed that should have invalidated the graph 
@@ -816,27 +875,35 @@ def fit(self, X): n_landmark = self.n_landmark if self.graph is not None and update_graph: - self._update_graph(X, precomputed, n_pca, n_landmark) + self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking) self.X = X if self.graph is None: with _logger.log_task("graph and diffusion operator"): - self.graph = graphtools.Graph( - X, - n_pca=n_pca, - n_landmark=n_landmark, - distance=self.knn_dist, - precomputed=precomputed, - knn=self.knn, - knn_max=self.knn_max, - decay=self.decay, - thresh=1e-4, - n_jobs=self.n_jobs, - verbose=self.verbose, - random_state=self.random_state, - **(self.kwargs), - ) + # Prepare graph params + graph_params = { + 'n_pca': n_pca, + 'n_landmark': n_landmark, + 'distance': self.knn_dist, + 'precomputed': precomputed, + 'knn': self.knn, + 'knn_max': self.knn_max, + 'decay': self.decay, + 'thresh': 1e-4, + 'n_jobs': self.n_jobs, + 'verbose': self.verbose, + 'random_state': self.random_state, + } + + # Only add random_landmarking if graphtools supports it + if _graphtools_supports_random_landmarking(): + graph_params['random_landmarking'] = self.random_landmarking + + # Merge with any additional kwargs + graph_params.update(self.kwargs) + + self.graph = graphtools.Graph(X, **graph_params) # landmark op doesn't build unless forced self.diff_op