Skip to content

Commit d5de920

Browse files
authored
Merge pull request #166 from MattScicluna/add_disconnection_warning
Add disconnection warning
2 parents e920469 + 04591bd commit d5de920

File tree

5 files changed

+66
-42
lines changed

5 files changed

+66
-42
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Python/doc/build/
5353

5454
# Jupyter
5555
.ipynb_checkpoints/
56+
Python/tutorial/cache
5657

5758
# Mac
5859
.DS_Store

Python/phate/mds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def embed_MDS(
218218
# which is much faster than scipy's pdist + squareform
219219
if distance_metric == "euclidean" and X.shape[0] > 1000:
220220
from sklearn.metrics.pairwise import euclidean_distances
221+
221222
X_dist = euclidean_distances(X, X)
222223
else:
223224
X_dist = squareform(pdist(X, distance_metric))
@@ -235,7 +236,7 @@ def embed_MDS(
235236
n_components=ndim,
236237
random_state=seed,
237238
init=Y_classic,
238-
verbose=verbose
239+
verbose=verbose,
239240
)
240241
elif solver == "smacof":
241242
Y = smacof(

Python/phate/phate.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,15 @@
3333

3434
_logger = tasklogger.get_tasklogger("graphtools")
3535

36-
# Check graphtools version for random_landmarking support
37-
def _graphtools_supports_random_landmarking():
38-
"""Check if installed graphtools version supports random_landmarking parameter."""
36+
37+
# Check graphtools version
38+
def _graphtools_version_is_at_least_2_0():
39+
"""Check if installed graphtools version is >= 2.0.0.
40+
41+
Version 2.0.0+ includes support for:
42+
- random_landmarking parameter
43+
- is_connected property and connectivity checks
44+
"""
3945
try:
4046
return version.parse(graphtools.__version__) >= version.parse("2.0.0")
4147
except AttributeError:
@@ -127,9 +133,9 @@ class PHATE(BaseEstimator):
127133
If an integer is given, it fixes the seed
128134
Defaults to the global `numpy` random number generator
129135
130-
random_landmarking : bool, optional, default: False
131-
Whether to use random sampling for landmarking. If True, landmarks
132-
are selected randomly. If False, landmarks are selected deterministically
136+
random_landmarking : bool, optional, default: False
137+
Whether to use random sampling for landmarking. If True, landmarks
138+
are selected randomly. If False, landmarks are selected deterministically
133139
using spectral clustering.
134140
Defaults to False.
135141
@@ -226,18 +232,18 @@ def __init__(
226232
"Landmarking is disabled when n_landmark=None. "
227233
"To use random landmarking, please set n_landmark to a positive integer "
228234
"(e.g., n_landmark=2000).",
229-
UserWarning
235+
UserWarning,
230236
)
231237
# Disable random_landmarking since it has no effect
232238
random_landmarking = False
233239
# Check graphtools version if random_landmarking is still requested
234-
elif random_landmarking and not _graphtools_supports_random_landmarking():
240+
elif random_landmarking and not _graphtools_version_is_at_least_2_0():
235241
warnings.warn(
236242
"random_landmarking is not available in graphtools version < 2.0.0. "
237243
"Please update graphtools to use this feature: "
238244
"https://pypi.org/project/graphtools/2.0.0/. "
239245
"Falling back to spectral clustering for landmark selection.",
240-
UserWarning
246+
UserWarning,
241247
)
242248
# Disable random_landmarking since it's not supported
243249
random_landmarking = False
@@ -527,9 +533,9 @@ def set_params(self, **params):
527533
If an integer is given, it fixes the seed
528534
Defaults to the global `numpy` random number generator
529535
530-
random_landmarking : bool, optional, default: False
531-
Whether to use random sampling for landmarking. If True, landmarks
532-
are selected randomly. If False, landmarks are selected deterministically
536+
random_landmarking : bool, optional, default: False
537+
Whether to use random sampling for landmarking. If True, landmarks
538+
are selected randomly. If False, landmarks are selected deterministically
533539
using spectral clustering.
534540
Defaults to False.
535541
@@ -814,21 +820,21 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking):
814820
try:
815821
# Prepare graph params
816822
graph_params = {
817-
'decay': self.decay,
818-
'knn': self.knn,
819-
'knn_max': self.knn_max,
820-
'distance': self.knn_dist,
821-
'precomputed': precomputed,
822-
'n_jobs': self.n_jobs,
823-
'verbose': self.verbose,
824-
'n_pca': n_pca,
825-
'n_landmark': n_landmark,
826-
'random_state': self.random_state,
823+
"decay": self.decay,
824+
"knn": self.knn,
825+
"knn_max": self.knn_max,
826+
"distance": self.knn_dist,
827+
"precomputed": precomputed,
828+
"n_jobs": self.n_jobs,
829+
"verbose": self.verbose,
830+
"n_pca": n_pca,
831+
"n_landmark": n_landmark,
832+
"random_state": self.random_state,
827833
}
828834

829835
# Only add random_landmarking if graphtools supports it
830-
if _graphtools_supports_random_landmarking():
831-
graph_params['random_landmarking'] = random_landmarking
836+
if _graphtools_version_is_at_least_2_0():
837+
graph_params["random_landmarking"] = random_landmarking
832838

833839
self.graph.set_params(**graph_params)
834840
_logger.log_info("Using precomputed graph and diffusion operator...")
@@ -875,36 +881,50 @@ def fit(self, X):
875881
n_landmark = self.n_landmark
876882

877883
if self.graph is not None and update_graph:
878-
self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking)
884+
self._update_graph(
885+
X, precomputed, n_pca, n_landmark, self.random_landmarking
886+
)
879887

880888
self.X = X
881889

882890
if self.graph is None:
883891
with _logger.log_task("graph and diffusion operator"):
884892
# Prepare graph params
885893
graph_params = {
886-
'n_pca': n_pca,
887-
'n_landmark': n_landmark,
888-
'distance': self.knn_dist,
889-
'precomputed': precomputed,
890-
'knn': self.knn,
891-
'knn_max': self.knn_max,
892-
'decay': self.decay,
893-
'thresh': 1e-4,
894-
'n_jobs': self.n_jobs,
895-
'verbose': self.verbose,
896-
'random_state': self.random_state,
894+
"n_pca": n_pca,
895+
"n_landmark": n_landmark,
896+
"distance": self.knn_dist,
897+
"precomputed": precomputed,
898+
"knn": self.knn,
899+
"knn_max": self.knn_max,
900+
"decay": self.decay,
901+
"thresh": 1e-4,
902+
"n_jobs": self.n_jobs,
903+
"verbose": self.verbose,
904+
"random_state": self.random_state,
897905
}
898906

899907
# Only add random_landmarking if graphtools supports it
900-
if _graphtools_supports_random_landmarking():
901-
graph_params['random_landmarking'] = self.random_landmarking
908+
if _graphtools_version_is_at_least_2_0():
909+
graph_params["random_landmarking"] = self.random_landmarking
902910

903911
# Merge with any additional kwargs
904912
graph_params.update(self.kwargs)
905913

906914
self.graph = graphtools.Graph(X, **graph_params)
907915

916+
# Check for graph connectivity (requires graphtools >= 2.0.0)
917+
if _graphtools_version_is_at_least_2_0():
918+
if not self.graph.is_connected:
919+
warnings.warn(
920+
f"Graph is disconnected with {self.graph.n_connected_components} "
921+
f"connected components. This may indicate that your knn parameter "
922+
f"(currently {self.knn}) is too small, or that your data contains "
923+
f"distinct clusters. PHATE may not accurately represent relationships "
924+
f"between disconnected components.",
925+
RuntimeWarning,
926+
)
927+
908928
# landmark op doesn't build unless forced
909929
self.diff_op
910930
return self

Python/phate/sgd_mds.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def sgd_mds(
159159
Y = Y - lr * gradients
160160

161161
# Compute stress for convergence checking
162-
stress = np.sum(errors ** 2) / len(errors) # Normalized by number of samples
162+
stress = np.sum(errors**2) / len(errors) # Normalized by number of samples
163163
stress_history.append(stress)
164164

165165
if verbose > 0 and iteration % 100 == 0:
@@ -189,7 +189,9 @@ def sgd_mds(
189189
last_10pct = max(1, len(stress_history) // 10)
190190
recent_stress = stress_history[-last_10pct:]
191191
if len(recent_stress) > 1:
192-
stress_trend = (recent_stress[-1] - recent_stress[0]) / (recent_stress[0] + 1e-10)
192+
stress_trend = (recent_stress[-1] - recent_stress[0]) / (
193+
recent_stress[0] + 1e-10
194+
)
193195
if abs(stress_trend) > 0.01: # Still changing by more than 1%
194196
_logger.log_warning(
195197
f"SGD-MDS may not have converged: stress changed by {stress_trend*100:.1f}% "

Python/test/test_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_tree():
126126
np.testing.assert_allclose(
127127
phate_precomputed_D, phate_precomputed_distance, atol=5e-4
128128
)
129-
129+
130130
return None
131131

132132

0 commit comments

Comments
 (0)