Skip to content

Commit 2d65034

Browse files
committed
added random_landmarking support for precomputed distance/affinity
1 parent 785c0c1 commit 2d65034

File tree

2 files changed

+78
-9
lines changed

2 files changed

+78
-9
lines changed

graphtools/graphs.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,16 +1201,45 @@ def build_landmark_op(self):
12011201
n_samples = self.data.shape[0]
12021202
rng = np.random.default_rng(self.random_state)
12031203
landmark_indices = rng.choice(n_samples, self.n_landmark, replace=False)
1204-
data = (
1205-
self.data if not hasattr(self, "data_nu") else self.data_nu
1206-
) # because of the scaling to review
1207-
if (
1208-
n_samples > 5000 and self.distance == "euclidean"
1209-
): # sklearn.euclidean_distances is faster than cdist for big dataset
1210-
distances = euclidean_distances(data, data[landmark_indices])
1204+
precomputed = getattr(self, "precomputed", None)
1205+
1206+
if precomputed is not None:
1207+
# Use the precomputed affinities/distances directly to avoid Euclidean fallback
1208+
landmark_affinities = self.kernel[:, landmark_indices]
1209+
1210+
if sparse.issparse(landmark_affinities):
1211+
landmark_affinities = landmark_affinities.tocsr()
1212+
cluster_assignments = np.asarray(
1213+
landmark_affinities.argmax(axis=1)
1214+
).reshape(-1)
1215+
row_max = matrix.to_array(
1216+
landmark_affinities.max(axis=1)
1217+
).reshape(-1)
1218+
else:
1219+
landmark_affinities = np.asarray(landmark_affinities)
1220+
cluster_assignments = np.argmax(landmark_affinities, axis=1)
1221+
row_max = np.max(landmark_affinities, axis=1)
1222+
1223+
if np.any(row_max == 0):
1224+
warnings.warn(
1225+
"Some samples have zero affinity to all randomly selected landmarks; "
1226+
"increase n_landmark or ensure the affinity matrix connects all points.",
1227+
RuntimeWarning,
1228+
)
1229+
self._clusters = cluster_assignments
12111230
else:
1212-
distances = cdist(data, data[landmark_indices], metric=self.distance)
1213-
self._clusters = np.argmin(distances, axis=1)
1231+
data = (
1232+
self.data if not hasattr(self, "data_nu") else self.data_nu
1233+
) # because of the scaling to review
1234+
if (
1235+
n_samples > 5000 and self.distance == "euclidean"
1236+
): # sklearn.euclidean_distances is faster than cdist for big dataset
1237+
distances = euclidean_distances(data, data[landmark_indices])
1238+
else:
1239+
distances = cdist(
1240+
data, data[landmark_indices], metric=self.distance
1241+
)
1242+
self._clusters = np.argmin(distances, axis=1)
12141243

12151244
else:
12161245
with _logger.log_task("SVD"):

test/test_random_landmarking.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,46 @@ def test_random_landmarking_distance_parameter_consistency():
405405
assert len(G.clusters) == small_data.shape[0]
406406

407407

408+
def test_random_landmarking_with_precomputed_affinity():
    """Precomputed affinity input must be honored by random landmarking.

    Builds a small symmetric affinity matrix, constructs a graph with
    random landmarking enabled, then checks that every sample is assigned
    to the landmark with which it shares the largest affinity.
    """
    n_landmark = 3
    seed = 42
    affinity = np.array(
        [
            [1.0, 0.8, 0.1, 0.0, 0.0, 0.0],
            [0.8, 1.0, 0.2, 0.0, 0.0, 0.0],
            [0.1, 0.2, 1.0, 0.9, 0.4, 0.0],
            [0.0, 0.0, 0.9, 1.0, 0.5, 0.2],
            [0.0, 0.0, 0.4, 0.5, 1.0, 0.9],
            [0.0, 0.0, 0.0, 0.2, 0.9, 1.0],
        ]
    )
    # Symmetrize explicitly so the precomputed input is a valid kernel.
    affinity = (affinity + affinity.T) / 2

    G = graphtools.Graph(
        affinity,
        precomputed="affinity",
        n_landmark=n_landmark,
        random_landmarking=True,
        random_state=seed,
        knn=3,
        thresh=0,
    )

    # Landmark construction is lazy; access the operator to force it.
    _ = G.landmark_op

    # Re-derive the landmark sample with an identical RNG stream, then
    # compute the expected argmax-based cluster assignment per sample.
    landmarks = np.random.default_rng(seed).choice(
        affinity.shape[0], n_landmark, replace=False
    )
    expected = np.asarray(G.kernel[:, landmarks].argmax(axis=1)).reshape(-1)

    assert np.array_equal(G.clusters, expected)
    assert G.transitions.shape == (affinity.shape[0], n_landmark)
    assert G.landmark_op.shape == (n_landmark, n_landmark)
446+
447+
408448
#############
409449
# Test API
410450
#############

0 commit comments

Comments (0)