Skip to content

Commit 80073f4

Browse files
author
Joao Felipe Rocha
committed
Merge remote-tracking branch 'origin/main' into add_pytests
2 parents bb74efd + d5de920 commit 80073f4

File tree

5 files changed

+199
-41
lines changed

5 files changed

+199
-41
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Python/doc/build/
5353

5454
# Jupyter
5555
.ipynb_checkpoints/
56+
Python/tutorial/cache
5657

5758
# Mac
5859
.DS_Store

Python/phate/mds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def embed_MDS(
218218
# which is much faster than scipy's pdist + squareform
219219
if distance_metric == "euclidean" and X.shape[0] > 1000:
220220
from sklearn.metrics.pairwise import euclidean_distances
221+
221222
X_dist = euclidean_distances(X, X)
222223
else:
223224
X_dist = squareform(pdist(X, distance_metric))
@@ -235,7 +236,7 @@ def embed_MDS(
235236
n_components=ndim,
236237
random_state=seed,
237238
init=Y_classic,
238-
verbose=verbose
239+
verbose=verbose,
239240
)
240241
elif solver == "smacof":
241242
Y = smacof(

Python/phate/phate.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,15 @@
3333

3434
_logger = tasklogger.get_tasklogger("graphtools")
3535

36-
# Check graphtools version for random_landmarking support
37-
def _graphtools_supports_random_landmarking():
38-
"""Check if installed graphtools version supports random_landmarking parameter."""
36+
37+
# Check graphtools version
38+
def _graphtools_version_is_at_least_2_0():
39+
"""Check if installed graphtools version is >= 2.0.0.
40+
41+
Version 2.0.0+ includes support for:
42+
- random_landmarking parameter
43+
- is_connected property and connectivity checks
44+
"""
3945
try:
4046
return version.parse(graphtools.__version__) >= version.parse("2.0.0")
4147
except AttributeError:
@@ -127,9 +133,9 @@ class PHATE(BaseEstimator):
127133
If an integer is given, it fixes the seed
128134
Defaults to the global `numpy` random number generator
129135
130-
random_landmarking : bool, optional, default: False
131-
Whether to use random sampling for landmarking. If True, landmarks
132-
are selected randomly. If False, landmarks are selected deterministically
136+
random_landmarking : bool, optional, default: False
137+
Whether to use random sampling for landmarking. If True, landmarks
138+
are selected randomly. If False, landmarks are selected deterministically
133139
using spectral clustering.
134140
Defaults to False.
135141
@@ -226,18 +232,18 @@ def __init__(
226232
"Landmarking is disabled when n_landmark=None. "
227233
"To use random landmarking, please set n_landmark to a positive integer "
228234
"(e.g., n_landmark=2000).",
229-
UserWarning
235+
UserWarning,
230236
)
231237
# Disable random_landmarking since it has no effect
232238
random_landmarking = False
233239
# Check graphtools version if random_landmarking is still requested
234-
elif random_landmarking and not _graphtools_supports_random_landmarking():
240+
elif random_landmarking and not _graphtools_version_is_at_least_2_0():
235241
warnings.warn(
236242
"random_landmarking is not available in graphtools version < 2.0.0. "
237243
"Please update graphtools to use this feature: "
238244
"https://pypi.org/project/graphtools/2.0.0/. "
239245
"Falling back to spectral clustering for landmark selection.",
240-
UserWarning
246+
UserWarning,
241247
)
242248
# Disable random_landmarking since it's not supported
243249
random_landmarking = False
@@ -527,9 +533,9 @@ def set_params(self, **params):
527533
If an integer is given, it fixes the seed
528534
Defaults to the global `numpy` random number generator
529535
530-
random_landmarking : bool, optional, default: False
531-
Whether to use random sampling for landmarking. If True, landmarks
532-
are selected randomly. If False, landmarks are selected deterministically
536+
random_landmarking : bool, optional, default: False
537+
Whether to use random sampling for landmarking. If True, landmarks
538+
are selected randomly. If False, landmarks are selected deterministically
533539
using spectral clustering.
534540
Defaults to False.
535541
@@ -814,21 +820,21 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking):
814820
try:
815821
# Prepare graph params
816822
graph_params = {
817-
'decay': self.decay,
818-
'knn': self.knn,
819-
'knn_max': self.knn_max,
820-
'distance': self.knn_dist,
821-
'precomputed': precomputed,
822-
'n_jobs': self.n_jobs,
823-
'verbose': self.verbose,
824-
'n_pca': n_pca,
825-
'n_landmark': n_landmark,
826-
'random_state': self.random_state,
823+
"decay": self.decay,
824+
"knn": self.knn,
825+
"knn_max": self.knn_max,
826+
"distance": self.knn_dist,
827+
"precomputed": precomputed,
828+
"n_jobs": self.n_jobs,
829+
"verbose": self.verbose,
830+
"n_pca": n_pca,
831+
"n_landmark": n_landmark,
832+
"random_state": self.random_state,
827833
}
828834

829835
# Only add random_landmarking if graphtools supports it
830-
if _graphtools_supports_random_landmarking():
831-
graph_params['random_landmarking'] = random_landmarking
836+
if _graphtools_version_is_at_least_2_0():
837+
graph_params["random_landmarking"] = random_landmarking
832838

833839
self.graph.set_params(**graph_params)
834840
_logger.log_info("Using precomputed graph and diffusion operator...")
@@ -875,36 +881,50 @@ def fit(self, X):
875881
n_landmark = self.n_landmark
876882

877883
if self.graph is not None and update_graph:
878-
self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking)
884+
self._update_graph(
885+
X, precomputed, n_pca, n_landmark, self.random_landmarking
886+
)
879887

880888
self.X = X
881889

882890
if self.graph is None:
883891
with _logger.log_task("graph and diffusion operator"):
884892
# Prepare graph params
885893
graph_params = {
886-
'n_pca': n_pca,
887-
'n_landmark': n_landmark,
888-
'distance': self.knn_dist,
889-
'precomputed': precomputed,
890-
'knn': self.knn,
891-
'knn_max': self.knn_max,
892-
'decay': self.decay,
893-
'thresh': 1e-4,
894-
'n_jobs': self.n_jobs,
895-
'verbose': self.verbose,
896-
'random_state': self.random_state,
894+
"n_pca": n_pca,
895+
"n_landmark": n_landmark,
896+
"distance": self.knn_dist,
897+
"precomputed": precomputed,
898+
"knn": self.knn,
899+
"knn_max": self.knn_max,
900+
"decay": self.decay,
901+
"thresh": 1e-4,
902+
"n_jobs": self.n_jobs,
903+
"verbose": self.verbose,
904+
"random_state": self.random_state,
897905
}
898906

899907
# Only add random_landmarking if graphtools supports it
900-
if _graphtools_supports_random_landmarking():
901-
graph_params['random_landmarking'] = self.random_landmarking
908+
if _graphtools_version_is_at_least_2_0():
909+
graph_params["random_landmarking"] = self.random_landmarking
902910

903911
# Merge with any additional kwargs
904912
graph_params.update(self.kwargs)
905913

906914
self.graph = graphtools.Graph(X, **graph_params)
907915

916+
# Check for graph connectivity (requires graphtools >= 2.0.0)
917+
if _graphtools_version_is_at_least_2_0():
918+
if not self.graph.is_connected:
919+
warnings.warn(
920+
f"Graph is disconnected with {self.graph.n_connected_components} "
921+
f"connected components. This may indicate that your knn parameter "
922+
f"(currently {self.knn}) is too small, or that your data contains "
923+
f"distinct clusters. PHATE may not accurately represent relationships "
924+
f"between disconnected components.",
925+
RuntimeWarning,
926+
)
927+
908928
# landmark op doesn't build unless forced
909929
self.diff_op
910930
return self

Python/phate/sgd_mds.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def sgd_mds(
159159
Y = Y - lr * gradients
160160

161161
# Compute stress for convergence checking
162-
stress = np.sum(errors ** 2) / len(errors) # Normalized by number of samples
162+
stress = np.sum(errors**2) / len(errors) # Normalized by number of samples
163163
stress_history.append(stress)
164164

165165
if verbose > 0 and iteration % 100 == 0:
@@ -189,7 +189,9 @@ def sgd_mds(
189189
last_10pct = max(1, len(stress_history) // 10)
190190
recent_stress = stress_history[-last_10pct:]
191191
if len(recent_stress) > 1:
192-
stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
192+
stress_trend = (recent_stress[-1] - recent_stress[0]) / (
193+
recent_stress[0] + 1e-10
194+
)
193195
if abs(stress_trend) > 0.01: # Still changing by more than 1%
194196
_logger.log_warning(
195197
f"SGD-MDS may not have converged: stress changed by {stress_trend*100:.1f}% "

Python/test/test_simple.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#!/usr/bin/env python
2+
# author: Daniel Burkhardt <[email protected]>
3+
# (C) 2017 Krishnaswamy Lab GPLv2
4+
5+
# Generating random fractal tree via DLA
6+
from __future__ import print_function, division, absolute_import
7+
8+
import os
9+
import phate
10+
import graphtools
11+
import pygsp
12+
import anndata
13+
import numpy as np
14+
from scipy.spatial.distance import pdist, squareform
15+
16+
import re
17+
18+
import pytest
19+
20+
import warnings
21+
22+
23+
def test_simple():
    """Smoke-test PHATE end to end on a small DLA tree.

    Checks, in order:

    * the operator's ``str`` / ``repr`` are strings,
    * ``fit_transform`` returns one 2-column row per input sample,
    * ``phate.cluster.kmeans`` returns a 1-D signed-integer label vector of
      the right length, both for ``n_clusters="auto"`` and ``n_clusters=3``,
    * ``fit`` accepts the operator's own graph, a graphtools graph built from
      its kernel, a pygsp graph and an ``AnnData`` object,
    * ``kmeans`` raises ``TypeError`` when given a non-PHATE argument.
    """
    # The cluster labels returned by gen_dla are not needed for this test.
    tree_data, _ = phate.tree.gen_dla(n_branch=3)
    n_samples = tree_data.shape[0]
    phate_operator = phate.PHATE(knn=15, t=100, verbose=False)

    # Go through the builtins instead of calling the dunder methods directly.
    assert isinstance(str(phate_operator), str)
    assert isinstance(repr(phate_operator), str)

    tree_phate = phate_operator.fit_transform(tree_data)
    assert tree_phate.shape == (n_samples, 2)

    # Automatic choice of the number of clusters.
    clusters = phate.cluster.kmeans(phate_operator, n_clusters="auto")
    assert np.issubdtype(clusters.dtype, np.signedinteger)
    assert len(np.unique(clusters)) >= 2
    assert len(clusters.shape) == 1
    assert len(clusters) == n_samples

    # Fixed number of clusters.
    clusters = phate.cluster.kmeans(phate_operator, n_clusters=3)
    assert np.issubdtype(clusters.dtype, np.signedinteger)
    assert len(np.unique(clusters)) == 3
    assert len(clusters.shape) == 1
    assert len(clusters) == n_samples

    # Refit from progressively different graph-like inputs.
    phate_operator.fit(phate_operator.graph)
    G = graphtools.Graph(
        phate_operator.graph.kernel,
        precomputed="affinity",
        use_pygsp=True,
        verbose=False,
    )
    phate_operator.fit(G)
    G = pygsp.graphs.Graph(G.W)
    phate_operator.fit(G)
    phate_operator.fit(anndata.AnnData(tree_data))

    # ``match`` is applied as a regular expression, so escape the literal
    # message; otherwise the "." would match any character.
    expected_message = "Expected phate_op to be of type PHATE. Got 1"
    with pytest.raises(TypeError, match=re.escape(expected_message)):
        phate.cluster.kmeans(1)
53+
54+
55+
def test_vne():
    """Check the knee point found on two fixed curves.

    The first curve is the von Neumann entropy computed from a perturbed
    10x10 identity matrix; the second is an exponential decay evaluated on
    an explicit x grid that is passed alongside it.
    """
    # Identity matrix with two entries overwritten.
    matrix = np.eye(10)
    matrix[0, 0] = 5
    matrix[3, 2] = 4
    entropy = phate.vne.compute_von_neumann_entropy(matrix)
    assert phate.vne.find_knee_point(entropy) == 23

    # Exponential decay with the x coordinates supplied explicitly.
    grid = np.arange(20)
    decay = np.exp(-grid / 10)
    assert phate.vne.find_knee_point(decay, grid) == 8
64+
65+
66+
def test_tree():
    """Run PHATE on a seeded DLA tree under several MDS and input settings.

    The test verifies that the constructor arguments reach the fitted graph,
    that every MDS configuration yields a 2-column embedding with one row per
    sample, and that the ``"precomputed"`` distance mode produces the same
    embedding (within ``atol=5e-4``) as the explicit
    ``"precomputed_distance"`` / ``"precomputed_affinity"`` modes.
    """
    # Generate the DLA tree; the second return value is not used here.
    tree, _ = phate.tree.gen_dla(
        n_dim=50, n_branch=4, branch_length=50, rand_multiplier=2, seed=37, sigma=4
    )
    embedding_shape = (tree.shape[0], 2)

    # Instantiate the operator.
    operator_kwargs = {
        "n_components": 2,
        "decay": 10,
        "knn": 5,
        "knn_max": 15,
        "t": 30,
        "mds": "classic",
        "knn_dist": "euclidean",
        "mds_dist": "euclidean",
        "n_jobs": -2,
        "n_landmark": None,
        "verbose": False,
    }
    op = phate.PHATE(**operator_kwargs)
    op.fit(tree)

    # The graph built during fit must carry the operator's settings.
    assert op.graph.knn == 5
    assert op.graph.knn_max == 15
    assert op.graph.decay == 10
    assert op.graph.n_jobs == -2
    assert op.graph.verbose == 0

    # Classic MDS.
    print("DLA tree, classic MDS")
    embedding_classic = op.fit_transform(tree)
    assert embedding_classic.shape == embedding_shape

    # Metric MDS: only the MDS setting changes between this run and the last.
    op.set_params(mds="metric")
    print("DLA tree, metric MDS (log)")
    embedding_metric = op.fit_transform(tree)
    assert embedding_metric.shape == embedding_shape

    # Still metric MDS, now with gamma=0 (labelled "sqrt" in the message below).
    op.set_params(gamma=0)
    print("DLA tree, metric MDS (sqrt)")
    embedding_sqrt = op.fit_transform(tree)
    assert embedding_sqrt.shape == embedding_shape

    # Precomputed inputs: a pairwise distance matrix and the fitted kernel.
    pairwise_dist = squareform(pdist(tree))
    kernel = op.graph.kernel
    op.set_params(knn_dist="precomputed", random_state=42, verbose=False)
    from_dist_generic = op.fit_transform(pairwise_dist)
    from_kernel_generic = op.fit_transform(kernel)

    op.set_params(knn_dist="precomputed_distance")
    from_dist_explicit = op.fit_transform(pairwise_dist)

    op.set_params(knn_dist="precomputed_affinity")
    from_kernel_explicit = op.fit_transform(kernel)

    # The generic mode must agree with the matching explicit mode.
    np.testing.assert_allclose(from_kernel_generic, from_kernel_explicit, atol=5e-4)
    np.testing.assert_allclose(from_dist_generic, from_dist_explicit, atol=5e-4)
131+
132+
133+
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # reports test failures through the process return code instead of
    # always exiting with 0.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)