Skip to content

Commit 2787202

Browse files
authored
Merge pull request #70 from sedfanne/random-landmarking
Random landmarking
2 parents 26e7c9f + cbc8768 commit 2787202

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

.github/workflows/run_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
fail-fast: false
2323
matrix:
2424
config:
25+
- {name: '3.11', os: ubuntu-latest, python: '3.11' }
2526
- {name: '3.10', os: ubuntu-latest, python: '3.10' }
2627
- {name: '3.9', os: ubuntu-latest, python: '3.9' }
2728
- {name: '3.8', os: ubuntu-latest, python: '3.8' }

graphtools/graphs.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from sklearn.neighbors import NearestNeighbors
1414
from sklearn.preprocessing import normalize
1515
from sklearn.utils.extmath import randomized_svd
16+
from sklearn.metrics.pairwise import euclidean_distances
17+
1618

1719
import numbers
1820
import numpy as np
@@ -82,7 +84,6 @@ def __init__(
8284
n_pca=None,
8385
**kwargs,
8486
):
85-
8687
if decay is not None:
8788
if thresh <= 0 and knn_max is None:
8889
raise ValueError(
@@ -489,7 +490,9 @@ class LandmarkGraph(DataGraph):
489490
>>> X_full = G.interpolate(X_landmark)
490491
"""
491492

492-
def __init__(self, data, n_landmark=2000, n_svd=100, **kwargs):
493+
def __init__(
494+
self, data, n_landmark=2000, n_svd=100, random_landmarking=False, **kwargs
495+
):
493496
"""Initialize a landmark graph.
494497
495498
Raises
@@ -508,6 +511,7 @@ def __init__(self, data, n_landmark=2000, n_svd=100, **kwargs):
508511
"using kNNGraph or lower n_svd".format(n_svd, data.shape[0]),
509512
RuntimeWarning,
510513
)
514+
self.random_landmarking = random_landmarking
511515
self.n_landmark = n_landmark
512516
self.n_svd = n_svd
513517
super().__init__(data, **kwargs)
@@ -637,28 +641,48 @@ def _data_transitions(self):
637641
def build_landmark_op(self):
638642
"""Build the landmark operator
639643
644+
640645
Calculates spectral clusters on the kernel, and calculates transition
641646
probabilities between cluster centers by using transition probabilities
642647
between samples assigned to each cluster.
648+
649+
random_landmarking:
650+
This method randomly selects n_landmark points and assigns each sample to its nearest landmark
651+
using the configured distance metric (`self.distance`).
652+
653+
643654
"""
644655
with _logger.log_task("landmark operator"):
645656
is_sparse = sparse.issparse(self.kernel)
646-
# spectral clustering
647-
with _logger.log_task("SVD"):
648-
_, _, VT = randomized_svd(
649-
self.diff_aff,
650-
n_components=self.n_svd,
651-
random_state=self.random_state,
652-
)
653-
with _logger.log_task("KMeans"):
654-
kmeans = MiniBatchKMeans(
655-
self.n_landmark,
656-
init_size=3 * self.n_landmark,
657-
n_init=1,
658-
batch_size=10000,
659-
random_state=self.random_state,
660-
)
661-
self._clusters = kmeans.fit_predict(self.diff_op.dot(VT.T))
657+
658+
if self.random_landmarking:
659+
n_samples = self.data.shape[0]
660+
rng = np.random.default_rng(self.random_state)
661+
landmark_indices = rng.choice(n_samples, self.n_landmark, replace=False)
662+
data = self.data if not hasattr(self, "data_nu") else self.data_nu
663+
# if n_samples > 5000 and self.distance == "euclidean": ( sklearn.euclidean_distances is faster than cdist for big dataset)
664+
# distances = euclidean_distances(data, data[landmark_indices])
665+
# this is a future optimization for the Euclidean case
666+
#
667+
distances = cdist(data, data[landmark_indices], metric=self.distance)
668+
self._clusters = np.argmin(distances, axis=1)
669+
670+
else:
671+
with _logger.log_task("SVD"):
672+
_, _, VT = randomized_svd(
673+
self.diff_aff,
674+
n_components=self.n_svd,
675+
random_state=self.random_state,
676+
)
677+
with _logger.log_task("KMeans"):
678+
kmeans = MiniBatchKMeans(
679+
self.n_landmark,
680+
init_size=3 * self.n_landmark,
681+
n_init=1,
682+
batch_size=10000,
683+
random_state=self.random_state,
684+
)
685+
self._clusters = kmeans.fit_predict(self.diff_op.dot(VT.T))
662686

663687
# transition matrices
664688
pmn = self._landmarks_to_data()

test/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_anndata_input():
101101
E2 = Estimator(verbose=0)
102102
E2.fit(anndata.AnnData(X))
103103
np.testing.assert_allclose(
104-
E.graph.K.toarray(), E2.graph.K.toarray(), rtol=1e-6, atol=2e-7
104+
E.graph.K.toarray(), E2.graph.K.toarray(), rtol=1e-6, atol=1e-6
105105
)
106106

107107

0 commit comments

Comments
 (0)