|
33 | 33 |
|
34 | 34 | _logger = tasklogger.get_tasklogger("graphtools") |
35 | 35 |
|
36 | | -# Check graphtools version for random_landmarking support |
37 | | -def _graphtools_supports_random_landmarking(): |
38 | | - """Check if installed graphtools version supports random_landmarking parameter.""" |
| 36 | + |
| 37 | +# Check graphtools version |
| 38 | +def _graphtools_version_is_at_least_2_0(): |
| 39 | + """Check if installed graphtools version is >= 2.0.0. |
| 40 | +
|
| 41 | + Version 2.0.0+ includes support for: |
| 42 | + - random_landmarking parameter |
| 43 | + - is_connected property and connectivity checks |
| 44 | + """ |
39 | 45 | try: |
40 | 46 | return version.parse(graphtools.__version__) >= version.parse("2.0.0") |
41 | 47 | except AttributeError: |
@@ -127,9 +133,9 @@ class PHATE(BaseEstimator): |
127 | 133 | If an integer is given, it fixes the seed |
128 | 134 | Defaults to the global `numpy` random number generator |
129 | 135 |
|
130 | | - random_landmarking : bool, optional, default: False |
131 | | - Whether to use random sampling for landmarking. If True, landmarks |
132 | | - are selected randomly. If False, landmarks are selected deterministically |
| 136 | + random_landmarking : bool, optional, default: False |
| 137 | + Whether to use random sampling for landmarking. If True, landmarks |
| 138 | + are selected randomly. If False, landmarks are selected deterministically |
133 | 139 | using spectral clustering. |
134 | 140 | Defaults to False. |
135 | 141 |
|
@@ -226,18 +232,18 @@ def __init__( |
226 | 232 | "Landmarking is disabled when n_landmark=None. " |
227 | 233 | "To use random landmarking, please set n_landmark to a positive integer " |
228 | 234 | "(e.g., n_landmark=2000).", |
229 | | - UserWarning |
| 235 | + UserWarning, |
230 | 236 | ) |
231 | 237 | # Disable random_landmarking since it has no effect |
232 | 238 | random_landmarking = False |
233 | 239 | # Check graphtools version if random_landmarking is still requested |
234 | | - elif random_landmarking and not _graphtools_supports_random_landmarking(): |
| 240 | + elif random_landmarking and not _graphtools_version_is_at_least_2_0(): |
235 | 241 | warnings.warn( |
236 | 242 | "random_landmarking is not available in graphtools version < 2.0.0. " |
237 | 243 | "Please update graphtools to use this feature: " |
238 | 244 | "https://pypi.org/project/graphtools/2.0.0/. " |
239 | 245 | "Falling back to spectral clustering for landmark selection.", |
240 | | - UserWarning |
| 246 | + UserWarning, |
241 | 247 | ) |
242 | 248 | # Disable random_landmarking since it's not supported |
243 | 249 | random_landmarking = False |
@@ -527,9 +533,9 @@ def set_params(self, **params): |
527 | 533 | If an integer is given, it fixes the seed |
528 | 534 | Defaults to the global `numpy` random number generator |
529 | 535 |
|
530 | | - random_landmarking : bool, optional, default: False |
531 | | - Whether to use random sampling for landmarking. If True, landmarks |
532 | | - are selected randomly. If False, landmarks are selected deterministically |
| 536 | + random_landmarking : bool, optional, default: False |
| 537 | + Whether to use random sampling for landmarking. If True, landmarks |
| 538 | + are selected randomly. If False, landmarks are selected deterministically |
533 | 539 | using spectral clustering. |
534 | 540 | Defaults to False. |
535 | 541 |
|
@@ -814,21 +820,21 @@ def _update_graph(self, X, precomputed, n_pca, n_landmark, random_landmarking): |
814 | 820 | try: |
815 | 821 | # Prepare graph params |
816 | 822 | graph_params = { |
817 | | - 'decay': self.decay, |
818 | | - 'knn': self.knn, |
819 | | - 'knn_max': self.knn_max, |
820 | | - 'distance': self.knn_dist, |
821 | | - 'precomputed': precomputed, |
822 | | - 'n_jobs': self.n_jobs, |
823 | | - 'verbose': self.verbose, |
824 | | - 'n_pca': n_pca, |
825 | | - 'n_landmark': n_landmark, |
826 | | - 'random_state': self.random_state, |
| 823 | + "decay": self.decay, |
| 824 | + "knn": self.knn, |
| 825 | + "knn_max": self.knn_max, |
| 826 | + "distance": self.knn_dist, |
| 827 | + "precomputed": precomputed, |
| 828 | + "n_jobs": self.n_jobs, |
| 829 | + "verbose": self.verbose, |
| 830 | + "n_pca": n_pca, |
| 831 | + "n_landmark": n_landmark, |
| 832 | + "random_state": self.random_state, |
827 | 833 | } |
828 | 834 |
|
829 | 835 | # Only add random_landmarking if graphtools supports it |
830 | | - if _graphtools_supports_random_landmarking(): |
831 | | - graph_params['random_landmarking'] = random_landmarking |
| 836 | + if _graphtools_version_is_at_least_2_0(): |
| 837 | + graph_params["random_landmarking"] = random_landmarking |
832 | 838 |
|
833 | 839 | self.graph.set_params(**graph_params) |
834 | 840 | _logger.log_info("Using precomputed graph and diffusion operator...") |
@@ -875,36 +881,50 @@ def fit(self, X): |
875 | 881 | n_landmark = self.n_landmark |
876 | 882 |
|
877 | 883 | if self.graph is not None and update_graph: |
878 | | - self._update_graph(X, precomputed, n_pca, n_landmark, self.random_landmarking) |
| 884 | + self._update_graph( |
| 885 | + X, precomputed, n_pca, n_landmark, self.random_landmarking |
| 886 | + ) |
879 | 887 |
|
880 | 888 | self.X = X |
881 | 889 |
|
882 | 890 | if self.graph is None: |
883 | 891 | with _logger.log_task("graph and diffusion operator"): |
884 | 892 | # Prepare graph params |
885 | 893 | graph_params = { |
886 | | - 'n_pca': n_pca, |
887 | | - 'n_landmark': n_landmark, |
888 | | - 'distance': self.knn_dist, |
889 | | - 'precomputed': precomputed, |
890 | | - 'knn': self.knn, |
891 | | - 'knn_max': self.knn_max, |
892 | | - 'decay': self.decay, |
893 | | - 'thresh': 1e-4, |
894 | | - 'n_jobs': self.n_jobs, |
895 | | - 'verbose': self.verbose, |
896 | | - 'random_state': self.random_state, |
| 894 | + "n_pca": n_pca, |
| 895 | + "n_landmark": n_landmark, |
| 896 | + "distance": self.knn_dist, |
| 897 | + "precomputed": precomputed, |
| 898 | + "knn": self.knn, |
| 899 | + "knn_max": self.knn_max, |
| 900 | + "decay": self.decay, |
| 901 | + "thresh": 1e-4, |
| 902 | + "n_jobs": self.n_jobs, |
| 903 | + "verbose": self.verbose, |
| 904 | + "random_state": self.random_state, |
897 | 905 | } |
898 | 906 |
|
899 | 907 | # Only add random_landmarking if graphtools supports it |
900 | | - if _graphtools_supports_random_landmarking(): |
901 | | - graph_params['random_landmarking'] = self.random_landmarking |
| 908 | + if _graphtools_version_is_at_least_2_0(): |
| 909 | + graph_params["random_landmarking"] = self.random_landmarking |
902 | 910 |
|
903 | 911 | # Merge with any additional kwargs |
904 | 912 | graph_params.update(self.kwargs) |
905 | 913 |
|
906 | 914 | self.graph = graphtools.Graph(X, **graph_params) |
907 | 915 |
|
| 916 | + # Check for graph connectivity (requires graphtools >= 2.0.0) |
| 917 | + if _graphtools_version_is_at_least_2_0(): |
| 918 | + if not self.graph.is_connected: |
| 919 | + warnings.warn( |
| 920 | + f"Graph is disconnected with {self.graph.n_connected_components} " |
| 921 | + f"connected components. This may indicate that your knn parameter " |
| 922 | + f"(currently {self.knn}) is too small, or that your data contains " |
| 923 | + f"distinct clusters. PHATE may not accurately represent relationships " |
| 924 | + f"between disconnected components.", |
| 925 | + RuntimeWarning, |
| 926 | + ) |
| 927 | + |
908 | 928 | # landmark op doesn't build unless forced |
909 | 929 | self.diff_op |
910 | 930 | return self |
|
0 commit comments