Skip to content

Commit dd1bddd

Browse files
authored
Merge pull request #161 from MattScicluna/add_random_landmarking
Add random landmarking
2 parents 3b68c3c + 0e03523 commit dd1bddd

File tree

1 file changed

+96
-29
lines changed

1 file changed

+96
-29
lines changed

Python/phate/phate.py

Lines changed: 96 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from scipy import sparse
1414
import warnings
1515
import tasklogger
16+
from packaging import version
1617

1718
import matplotlib.pyplot as plt
1819

@@ -32,6 +33,15 @@
3233

3334
_logger = tasklogger.get_tasklogger("graphtools")
3435

36+
# Check graphtools version for random_landmarking support
37+
def _graphtools_supports_random_landmarking():
    """Check if installed graphtools version supports random_landmarking.

    Returns
    -------
    bool
        True if ``graphtools.__version__`` parses to version 2.0.0 or later.
        False if it is older, if graphtools exposes no ``__version__``
        attribute, or if the attribute cannot be parsed as a version.
    """
    try:
        installed = version.parse(graphtools.__version__)
    except AttributeError:
        # graphtools doesn't have __version__, assume old version
        return False
    except (version.InvalidVersion, TypeError):
        # Newer `packaging` releases raise InvalidVersion for strings that
        # are not valid PEP 440 versions (and TypeError for non-strings).
        # Treat an unreadable version as unsupported so callers fall back
        # to the default landmarking instead of crashing.
        return False
    return installed >= version.parse("2.0.0")
44+
3545

3646
class PHATE(BaseEstimator):
3747
"""PHATE operator which performs dimensionality reduction.
@@ -117,6 +127,12 @@ class PHATE(BaseEstimator):
117127
If an integer is given, it fixes the seed
118128
Defaults to the global `numpy` random number generator
119129
130+
random_landmarking : bool, optional, default: False
131+
Whether to use random sampling for landmarking. If True, landmarks
132+
are selected randomly. If False, landmarks are selected deterministically
133+
using spectral clustering.
134+
Defaults to False.
135+
120136
verbose : `int` or `boolean`, optional (default: 1)
121137
If `True` or `> 0`, print status messages
122138
@@ -178,6 +194,7 @@ def __init__(
178194
mds="metric",
179195
n_jobs=1,
180196
random_state=None,
197+
random_landmarking=False,
181198
verbose=1,
182199
**kwargs,
183200
):
@@ -201,6 +218,31 @@ def __init__(
201218
self.mds_dist = mds_dist
202219
self.mds_solver = mds_solver
203220
self.random_state = random_state
221+
222+
# Validate random_landmarking parameter
223+
if random_landmarking and n_landmark is None:
224+
warnings.warn(
225+
"random_landmarking=True has no effect when n_landmark=None. "
226+
"Landmarking is disabled when n_landmark=None. "
227+
"To use random landmarking, please set n_landmark to a positive integer "
228+
"(e.g., n_landmark=2000).",
229+
UserWarning
230+
)
231+
# Disable random_landmarking since it has no effect
232+
random_landmarking = False
233+
# Check graphtools version if random_landmarking is still requested
234+
elif random_landmarking and not _graphtools_supports_random_landmarking():
235+
warnings.warn(
236+
"random_landmarking is not available in graphtools version < 2.0.0. "
237+
"Please update graphtools to use this feature: "
238+
"https://pypi.org/project/graphtools/2.0.0/. "
239+
"Falling back to spectral clustering for landmark selection.",
240+
UserWarning
241+
)
242+
# Disable random_landmarking since it's not supported
243+
random_landmarking = False
244+
245+
self.random_landmarking = random_landmarking
204246
self.kwargs = kwargs
205247

206248
self.graph = None
@@ -485,6 +527,12 @@ def set_params(self, **params):
485527
If an integer is given, it fixes the seed
486528
Defaults to the global `numpy` random number generator
487529
530+
random_landmarking : bool, optional, default: False
531+
Whether to use random sampling for landmarking. If True, landmarks
532+
are selected randomly. If False, landmarks are selected deterministically
533+
using spectral clustering.
534+
Defaults to False.
535+
488536
verbose : `int` or `boolean`, optional (default: 1)
489537
If `True` or `> 0`, print status messages
490538
@@ -615,6 +663,10 @@ def set_params(self, **params):
615663
self.random_state = params["random_state"]
616664
self._set_graph_params(random_state=params["random_state"])
617665
del params["random_state"]
666+
if "random_landmarking" in params:
667+
self.random_landmarking = params["random_landmarking"]
668+
self._set_graph_params(random_landmarking=params["random_landmarking"])
669+
del params["random_landmarking"]
618670
if "verbose" in params:
619671
self.verbose = params["verbose"]
620672
_logger.set_level(self.verbose)
@@ -751,7 +803,7 @@ def _parse_input(self, X):
751803
n_pca = self.n_pca
752804
return X, n_pca, precomputed, update_graph
753805

754-
def _update_graph(self, X, precomputed, n_pca, n_landmark):
806+
def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking):
755807
if self.X is not None and not utils.matrix_is_equivalent(X, self.X):
756808
"""
757809
If the same data is used, we can reuse existing kernel and
@@ -760,18 +812,25 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark):
760812
self._reset_graph()
761813
else:
762814
try:
763-
self.graph.set_params(
764-
decay=self.decay,
765-
knn=self.knn,
766-
knn_max=self.knn_max,
767-
distance=self.knn_dist,
768-
precomputed=precomputed,
769-
n_jobs=self.n_jobs,
770-
verbose=self.verbose,
771-
n_pca=n_pca,
772-
n_landmark=n_landmark,
773-
random_state=self.random_state,
774-
)
815+
# Prepare graph params
816+
graph_params = {
817+
'decay': self.decay,
818+
'knn': self.knn,
819+
'knn_max': self.knn_max,
820+
'distance': self.knn_dist,
821+
'precomputed': precomputed,
822+
'n_jobs': self.n_jobs,
823+
'verbose': self.verbose,
824+
'n_pca': n_pca,
825+
'n_landmark': n_landmark,
826+
'random_state': self.random_state,
827+
}
828+
829+
# Only add random_landmarking if graphtools supports it
830+
if _graphtools_supports_random_landmarking():
831+
graph_params['random_landmarking'] = random_landmarking
832+
833+
self.graph.set_params(**graph_params)
775834
_logger.info("Using precomputed graph and diffusion operator...")
776835
except ValueError as e:
777836
# something changed that should have invalidated the graph
@@ -816,27 +875,35 @@ def fit(self, X):
816875
n_landmark = self.n_landmark
817876

818877
if self.graph is not None and update_graph:
819-
self._update_graph(X, precomputed, n_pca, n_landmark)
878+
self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking)
820879

821880
self.X = X
822881

823882
if self.graph is None:
824883
with _logger.log_task("graph and diffusion operator"):
825-
self.graph = graphtools.Graph(
826-
X,
827-
n_pca=n_pca,
828-
n_landmark=n_landmark,
829-
distance=self.knn_dist,
830-
precomputed=precomputed,
831-
knn=self.knn,
832-
knn_max=self.knn_max,
833-
decay=self.decay,
834-
thresh=1e-4,
835-
n_jobs=self.n_jobs,
836-
verbose=self.verbose,
837-
random_state=self.random_state,
838-
**(self.kwargs),
839-
)
884+
# Prepare graph params
885+
graph_params = {
886+
'n_pca': n_pca,
887+
'n_landmark': n_landmark,
888+
'distance': self.knn_dist,
889+
'precomputed': precomputed,
890+
'knn': self.knn,
891+
'knn_max': self.knn_max,
892+
'decay': self.decay,
893+
'thresh': 1e-4,
894+
'n_jobs': self.n_jobs,
895+
'verbose': self.verbose,
896+
'random_state': self.random_state,
897+
}
898+
899+
# Only add random_landmarking if graphtools supports it
900+
if _graphtools_supports_random_landmarking():
901+
graph_params['random_landmarking'] = self.random_landmarking
902+
903+
# Merge with any additional kwargs
904+
graph_params.update(self.kwargs)
905+
906+
self.graph = graphtools.Graph(X, **graph_params)
840907

841908
# landmark op doesn't build unless forced
842909
self.diff_op

0 commit comments

Comments
 (0)