Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 96 additions & 29 deletions Python/phate/phate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scipy import sparse
import warnings
import tasklogger
from packaging import version

import matplotlib.pyplot as plt

Expand All @@ -32,6 +33,15 @@

_logger = tasklogger.get_tasklogger("graphtools")

# Check graphtools version for random_landmarking support
def _graphtools_supports_random_landmarking():
"""Check if installed graphtools version supports random_landmarking parameter."""
try:
return version.parse(graphtools.__version__) >= version.parse("2.0.0")
except AttributeError:
# graphtools doesn't have __version__, assume old version
return False


class PHATE(BaseEstimator):
"""PHATE operator which performs dimensionality reduction.
Expand Down Expand Up @@ -117,6 +127,12 @@ class PHATE(BaseEstimator):
If an integer is given, it fixes the seed
Defaults to the global `numpy` random number generator

random_landmarking : bool, optional, default: False
Whether to use random sampling for landmarking. If True, landmarks
are selected randomly. If False, landmarks are selected deterministically
using spectral clustering.
Defaults to False.

verbose : `int` or `boolean`, optional (default: 1)
If `True` or `> 0`, print status messages

Expand Down Expand Up @@ -178,6 +194,7 @@ def __init__(
mds="metric",
n_jobs=1,
random_state=None,
random_landmarking=False,
verbose=1,
**kwargs,
):
Expand All @@ -201,6 +218,31 @@ def __init__(
self.mds_dist = mds_dist
self.mds_solver = mds_solver
self.random_state = random_state

# Validate random_landmarking parameter
if random_landmarking and n_landmark is None:
warnings.warn(
"random_landmarking=True has no effect when n_landmark=None. "
"Landmarking is disabled when n_landmark=None. "
"To use random landmarking, please set n_landmark to a positive integer "
"(e.g., n_landmark=2000).",
UserWarning
)
# Disable random_landmarking since it has no effect
random_landmarking = False
# Check graphtools version if random_landmarking is still requested
elif random_landmarking and not _graphtools_supports_random_landmarking():
warnings.warn(
"random_landmarking is not available in graphtools version < 2.0.0. "
"Please update graphtools to use this feature: "
"https://pypi.org/project/graphtools/2.0.0/. "
"Falling back to spectral clustering for landmark selection.",
UserWarning
)
# Disable random_landmarking since it's not supported
random_landmarking = False

self.random_landmarking = random_landmarking
self.kwargs = kwargs

self.graph = None
Expand Down Expand Up @@ -485,6 +527,12 @@ def set_params(self, **params):
If an integer is given, it fixes the seed
Defaults to the global `numpy` random number generator

random_landmarking : bool, optional, default: False
Whether to use random sampling for landmarking. If True, landmarks
are selected randomly. If False, landmarks are selected deterministically
using spectral clustering.
Defaults to False.

verbose : `int` or `boolean`, optional (default: 1)
If `True` or `> 0`, print status messages

Expand Down Expand Up @@ -615,6 +663,10 @@ def set_params(self, **params):
self.random_state = params["random_state"]
self._set_graph_params(random_state=params["random_state"])
del params["random_state"]
if "random_landmarking" in params:
self.random_landmarking = params["random_landmarking"]
self._set_graph_params(random_landmarking=params["random_landmarking"])
del params["random_landmarking"]
if "verbose" in params:
self.verbose = params["verbose"]
_logger.set_level(self.verbose)
Expand Down Expand Up @@ -751,7 +803,7 @@ def _parse_input(self, X):
n_pca = self.n_pca
return X, n_pca, precomputed, update_graph

def _update_graph(self, X, precomputed, n_pca, n_landmark):
def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking):
if self.X is not None and not utils.matrix_is_equivalent(X, self.X):
"""
If the same data is used, we can reuse existing kernel and
Expand All @@ -760,18 +812,25 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark):
self._reset_graph()
else:
try:
self.graph.set_params(
decay=self.decay,
knn=self.knn,
knn_max=self.knn_max,
distance=self.knn_dist,
precomputed=precomputed,
n_jobs=self.n_jobs,
verbose=self.verbose,
n_pca=n_pca,
n_landmark=n_landmark,
random_state=self.random_state,
)
# Prepare graph params
graph_params = {
'decay': self.decay,
'knn': self.knn,
'knn_max': self.knn_max,
'distance': self.knn_dist,
'precomputed': precomputed,
'n_jobs': self.n_jobs,
'verbose': self.verbose,
'n_pca': n_pca,
'n_landmark': n_landmark,
'random_state': self.random_state,
}

# Only add random_landmarking if graphtools supports it
if _graphtools_supports_random_landmarking():
graph_params['random_landmarking'] = random_landmarking

self.graph.set_params(**graph_params)
_logger.info("Using precomputed graph and diffusion operator...")
except ValueError as e:
# something changed that should have invalidated the graph
Expand Down Expand Up @@ -816,27 +875,35 @@ def fit(self, X):
n_landmark = self.n_landmark

if self.graph is not None and update_graph:
self._update_graph(X, precomputed, n_pca, n_landmark)
self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking)

self.X = X

if self.graph is None:
with _logger.log_task("graph and diffusion operator"):
self.graph = graphtools.Graph(
X,
n_pca=n_pca,
n_landmark=n_landmark,
distance=self.knn_dist,
precomputed=precomputed,
knn=self.knn,
knn_max=self.knn_max,
decay=self.decay,
thresh=1e-4,
n_jobs=self.n_jobs,
verbose=self.verbose,
random_state=self.random_state,
**(self.kwargs),
)
# Prepare graph params
graph_params = {
'n_pca': n_pca,
'n_landmark': n_landmark,
'distance': self.knn_dist,
'precomputed': precomputed,
'knn': self.knn,
'knn_max': self.knn_max,
'decay': self.decay,
'thresh': 1e-4,
'n_jobs': self.n_jobs,
'verbose': self.verbose,
'random_state': self.random_state,
}

# Only add random_landmarking if graphtools supports it
if _graphtools_supports_random_landmarking():
graph_params['random_landmarking'] = self.random_landmarking

# Merge with any additional kwargs
graph_params.update(self.kwargs)

self.graph = graphtools.Graph(X, **graph_params)

# landmark op doesn't build unless forced
self.diff_op
Expand Down
Loading