Commit d7220a8

Fixed a small bug with graphs (so random_landmark graphs keep all attributes); added the random_landmarking attribute.
1 parent 81eda46 commit d7220a8

3 files changed: +28 additions, -31 deletions

graphtools/api.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@ def Graph(
     adaptive_k=None,
     n_landmark=None,
     n_svd=100,
+    random_landmarking=False,
     n_jobs=-1,
     verbose=False,
     random_state=None,
@@ -141,6 +142,11 @@ def Graph(
     n_svd : `int`, optional (default: 100)
         number of SVD components to use for spectral clustering
+    random_landmarking : `bool`, optional (default: False)
+        If True, use random landmark selection instead of spectral clustering.
+        Randomly selects n_landmark points and assigns samples to the nearest landmark.
+        Only used when n_landmark is not None.
+
     random_state : `int` or `None`, optional (default: `None`)
         Random state for random PCA

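For orientation, here is a minimal usage sketch of the new flag. The data and parameter values are illustrative; `graphtools.Graph`, `n_landmark`, and `random_state` are existing API, and `random_landmarking` is the option this commit adds:

import numpy as np
import graphtools

# illustrative data: 3,000 samples with 50 features
X = np.random.default_rng(0).normal(size=(3000, 50))

# default: landmarks come from spectral clustering on the kernel (uses n_svd)
G_spectral = graphtools.Graph(X, n_landmark=500, random_state=42)

# new option: pick 500 random points as landmarks and skip the SVD step
G_random = graphtools.Graph(X, n_landmark=500, random_landmarking=True, random_state=42)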
graphtools/graphs.py

Lines changed: 21 additions & 30 deletions
@@ -41,8 +41,6 @@ def decorator(func):
 @njit(parallel=True)
 def _numba_compute_kernel_matrix(distances, indices, bandwidth, decay, thresh):
     """
-    Advanced PHATE-inspired numba kernel computation.
-
     Key optimizations:
     - Uses float32 for memory efficiency
     - Vectorized operations with in-place modifications
@@ -117,8 +115,6 @@ def _numba_compute_kernel_matrix(distances, indices, bandwidth, decay, thresh):
 @njit(parallel=True)
 def _numba_process_kernel_data_vectorized(distances, bandwidth, decay, thresh):
     """
-    PHATE-inspired vectorized kernel computation with advanced optimizations.
-
     Key improvements:
     - In-place vectorized operations like PHATE
     - float32 precision for memory efficiency
@@ -165,8 +161,6 @@ def _numba_process_kernel_data_vectorized(distances, bandwidth, decay, thresh):
 @njit(parallel=True)
 def _numba_build_csr_components(data, indices, valid_mask, n_rows, n_cols):
     """
-    PHATE-inspired efficient CSR matrix construction.
-
     Optimizations:
     - Parallel row counting
     - Efficient memory pre-allocation
@@ -209,9 +203,7 @@ def _numba_build_csr_components(data, indices, valid_mask, n_rows, n_cols):
 @njit(parallel=True)
 def _numba_build_kernel_to_data_optimized(pdx, bandwidth, decay, thresh):
     """
-    PHATE-inspired optimized kernel-to-data computation with numba.
-
-    This function implements the core optimizations from PHATE benchmarks:
+    This function implements optimizations:
     - float32 precision for memory efficiency
     - Vectorized in-place operations
     - Parallel processing
@@ -871,7 +863,7 @@ def __init__(self, data, n_landmark=2000, n_svd=100, random_landmarking=False, *
                 "n_landmark ({}) >= n_samples ({}). Use "
                 "kNNGraph instead".format(n_landmark, data.shape[0])
             )
-        if n_svd >= data.shape[0]:
+        if (n_svd >= data.shape[0]) and (not random_landmarking):
             warnings.warn(
                 "n_svd ({}) >= n_samples ({}) Consider "
                 "using kNNGraph or lower n_svd".format(n_svd, data.shape[0]),
@@ -885,7 +877,7 @@ def __init__(self, data, n_landmark=2000, n_svd=100, random_landmarking=False, *
     def get_params(self):
         """Get parameters from this object"""
         params = super().get_params()
-        params.update({"n_landmark": self.n_landmark, "n_pca": self.n_pca})
+        params.update({"n_landmark": self.n_landmark, "n_pca": self.n_pca, "random_landmarking": self.random_landmarking})
         return params
 
     def set_params(self, **params):
@@ -896,6 +888,7 @@ def set_params(self, **params):
         Valid parameters:
         - n_landmark
         - n_svd
+        - random_landmarking
 
         Parameters
         ----------
@@ -913,6 +906,9 @@ def set_params(self, **params):
         if "n_svd" in params and params["n_svd"] != self.n_svd:
             self.n_svd = params["n_svd"]
             reset_landmarks = True
+        if "random_landmarking" in params and params["random_landmarking"] != self.random_landmarking:
+            self.random_landmarking = params["random_landmarking"]
+            reset_landmarks = True
         # update superclass parameters
         super().set_params(**params)
         # reset things that changed
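Together with the `get_params` change above, this is the attribute-preservation fix named in the commit message: the flag now round-trips through `get_params`/`set_params`, and changing it resets the cached landmarks. Continuing the sketch from the api.py section:

# the new flag is reported alongside the other landmark parameters
assert G_random.get_params()["random_landmarking"] is True

# flipping it via set_params marks the landmark operator for recomputation
G_random.set_params(random_landmarking=False)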
@@ -1007,7 +1003,6 @@ def _data_transitions(self):
     def build_landmark_op(self):
         """Build the landmark operator
 
-
         Calculates spectral clusters on the kernel, and calculates transition
         probabilities between cluster centers by using transition probabilities
         between samples assigned to each cluster.
@@ -1016,7 +1011,6 @@ def build_landmark_op(self):
         This method randomly selects n_landmark points and assigns each sample to its nearest landmark
         using Euclidean distance.
 
-
         """
         if self.random_landmarking:
             with _logger.log_task("landmark operator"):
@@ -1025,7 +1019,6 @@ def build_landmark_op(self):
                 rng = np.random.default_rng(self.random_state)
                 landmark_indices = rng.choice(n_samples, self.n_landmark, replace=False)
                 data = self.data if not hasattr(self, 'data_nu') else self.data_nu  # because of the scaling to review
-                distances = cdist(data, data[landmark_indices], metric="euclidean")
                 if n_samples > 5000:  # sklearn.euclidean_distances is faster than cdist for big dataset
                     distances = euclidean_distances(data, data[landmark_indices])
                 else:
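The random path replaces spectral clustering with a uniform subsample: draw `n_landmark` row indices, compute sample-to-landmark Euclidean distances, and label each sample with its nearest landmark. A self-contained sketch of that assignment step (a hypothetical standalone helper, not the class method itself):

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

def random_landmark_clusters(data, n_landmark, random_state=None):
    """Assign each row of data to its nearest randomly chosen landmark."""
    rng = np.random.default_rng(random_state)
    landmark_indices = rng.choice(data.shape[0], n_landmark, replace=False)
    # (n_samples, n_landmark) distances to the sampled landmarks
    distances = euclidean_distances(data, data[landmark_indices])
    # cluster label = index of the nearest landmark
    return landmark_indices, np.argmin(distances, axis=1)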
@@ -1052,23 +1045,21 @@ def build_landmark_op(self):
                 )
                 self._clusters = kmeans.fit_predict(self.diff_op.dot(VT.T))
 
+        # transition matrices
+        pmn = self._landmarks_to_data()
 
-
-            # transition matrices
-            pmn = self._landmarks_to_data()
-
-            # row normalize
-            pnm = pmn.transpose()
-            pmn = normalize(pmn, norm="l1", axis=1)
-            pnm = normalize(pnm, norm="l1", axis=1)
-            # sparsity agnostic matrix multiplication
-            landmark_op = pmn.dot(pnm)
-            if is_sparse:
-                # no need to have a sparse landmark operator
-                landmark_op = landmark_op.toarray()
-            # store output
-            self._landmark_op = landmark_op
-            self._transitions = pnm
+        # row normalize
+        pnm = pmn.transpose()
+        pmn = normalize(pmn, norm="l1", axis=1)
+        pnm = normalize(pnm, norm="l1", axis=1)
+        # sparsity agnostic matrix multiplication
+        landmark_op = pmn.dot(pnm)
+        if is_sparse:
+            # no need to have a sparse landmark operator
+            landmark_op = landmark_op.toarray()
+        # store output
+        self._landmark_op = landmark_op
+        self._transitions = pnm
 
     def extend_to_data(self, data, **kwargs):
         """Build transition matrix from new data to the graph

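The dedent in the last hunk is the substance of the bug fix: the transition-matrix tail used to live inside the spectral-clustering branch, so graphs built with `random_landmarking` never set `_landmark_op` or `_transitions`. Hoisted out, it runs for both branches: `pmn` holds landmark-to-sample kernel weights, each direction is L1-normalized into a row-stochastic matrix, and their product is the landmark-to-landmark diffusion operator. A runnable toy version of that tail (the random sparse `pmn` stands in for `self._landmarks_to_data()`):

from scipy import sparse
from sklearn.preprocessing import normalize

# toy stand-in: 5 landmarks x 100 samples of nonnegative kernel weights
pmn = sparse.random(5, 100, density=0.2, format="csr", random_state=0)

pnm = normalize(pmn.transpose(), norm="l1", axis=1)  # P(sample -> landmark)
pmn = normalize(pmn, norm="l1", axis=1)              # P(landmark -> sample)
landmark_op = pmn.dot(pnm).toarray()                 # landmark -> landmark transitions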
test/load_tests/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def build_graph(
         n_pca=n_pca,
         decay=decay,
         knn=knn,
-        random_state=42,
+        random_state=random_state,
         verbose=verbose,
         **kwargs,
     )
