Skip to content

Commit 85f0e22

Browse files
committed
don't include self for knn
1 parent 292cb42 commit 85f0e22

File tree

5 files changed

+112
-157
lines changed

5 files changed

+112
-157
lines changed

graphtools/graphs.py

Lines changed: 41 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ def __init__(self, data, knn=5, decay=None,
8585
if decay is None and bandwidth is not None:
8686
warnings.warn("`bandwidth` is not used when `decay=None`.",
8787
UserWarning)
88-
if knn > data.shape[0]:
88+
if knn > data.shape[0] - 2:
8989
warnings.warn("Cannot set knn ({k}) to be greater than "
9090
"n_samples ({n}). Setting knn={n}".format(
91-
k=knn, n=data.shape[0]))
92-
knn = data.shape[0]
91+
k=knn, n=data.shape[0] - 2))
92+
knn = data.shape[0] - 2
9393
if n_pca is None and data.shape[1] > 500:
9494
warnings.warn("Building a kNNGraph on data of shape {} is "
9595
"expensive. Consider setting n_pca.".format(
@@ -189,7 +189,7 @@ def knn_tree(self):
189189
except AttributeError:
190190
try:
191191
self._knn_tree = NearestNeighbors(
192-
n_neighbors=self.knn,
192+
n_neighbors=self.knn + 1,
193193
algorithm='ball_tree',
194194
metric=self.distance,
195195
n_jobs=self.n_jobs).fit(self.data_nu)
@@ -201,7 +201,7 @@ def knn_tree(self):
201201
self.distance),
202202
UserWarning)
203203
self._knn_tree = NearestNeighbors(
204-
n_neighbors=self.knn,
204+
n_neighbors=self.knn + 1,
205205
algorithm='auto',
206206
metric=self.distance,
207207
n_jobs=self.n_jobs).fit(self.data_nu)
@@ -219,9 +219,35 @@ def build_kernel(self):
219219
symmetric matrix with ones down the diagonal
220220
with no non-negative entries.
221221
"""
222-
K = self.build_kernel_to_data(self.data_nu)
222+
K = self.build_kernel_to_data(self.data_nu, knn=self.knn + 1)
223223
return K
224224

225+
def _check_duplicates(self, distances, indices):
226+
if np.any(distances[:, 1] == 0):
227+
has_duplicates = distances[:, 1] == 0
228+
if np.sum(distances[:, 1:] == 0) < 20:
229+
idx = np.argwhere((distances == 0) &
230+
has_duplicates[:, None])
231+
duplicate_ids = np.array(
232+
[[indices[i[0], i[1]], i[0]]
233+
for i in idx if indices[i[0], i[1]] < i[0]])
234+
duplicate_ids = duplicate_ids[
235+
np.argsort(duplicate_ids[:, 0])]
236+
duplicate_names = ", ".join(["{} and {}".format(i[0], i[1])
237+
for i in duplicate_ids])
238+
warnings.warn(
239+
"Detected zero distance between samples {}. "
240+
"Consider removing duplicates to avoid errors in "
241+
"downstream processing.".format(duplicate_names),
242+
RuntimeWarning)
243+
else:
244+
warnings.warn(
245+
"Detected zero distance between {} pairs of samples. "
246+
"Consider removing duplicates to avoid errors in "
247+
"downstream processing.".format(
248+
np.sum(np.sum(distances[:, 1:]))),
249+
RuntimeWarning)
250+
225251
def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
226252
bandwidth_scale=None):
227253
"""Build a kernel from new input data `Y` to the `self.data`
@@ -281,30 +307,7 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
281307
search_knn = min(knn * 20, self.data_nu.shape[0])
282308
distances, indices = knn_tree.kneighbors(
283309
Y, n_neighbors=search_knn)
284-
if np.any(distances[:, 1] == 0):
285-
has_duplicates = distances[:, 1] == 0
286-
if np.sum(distances[:, 1:] == 0) < 20:
287-
idx = np.argwhere((distances == 0) &
288-
has_duplicates[:, None])
289-
duplicate_ids = np.array(
290-
[[indices[i[0], i[1]], i[0]]
291-
for i in idx if indices[i[0], i[1]] < i[0]])
292-
duplicate_ids = duplicate_ids[
293-
np.argsort(duplicate_ids[:, 0])]
294-
duplicate_names = ", ".join(["{} and {}".format(i[0], i[1])
295-
for i in duplicate_ids])
296-
warnings.warn(
297-
"Detected zero distance between samples {}. "
298-
"Consider removing duplicates to avoid errors in "
299-
"downstream processing.".format(duplicate_names),
300-
RuntimeWarning)
301-
else:
302-
warnings.warn(
303-
"Detected zero distance between {} pairs of samples. "
304-
"Consider removing duplicates to avoid errors in "
305-
"downstream processing.".format(
306-
np.sum(np.sum(distances[:, 1:]))),
307-
RuntimeWarning)
310+
self._check_duplicates(distances, indices)
308311
tasklogger.log_complete("KNN search")
309312
tasklogger.log_start("affinities")
310313
if bandwidth is None:
@@ -338,7 +341,7 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
338341
len(update_idx)))
339342
if search_knn > self.data_nu.shape[0] / 2:
340343
knn_tree = NearestNeighbors(
341-
knn, algorithm='brute',
344+
search_knn, algorithm='brute',
342345
n_jobs=self.n_jobs).fit(self.data_nu)
343346
if len(update_idx) > 0:
344347
tasklogger.log_debug(
@@ -771,11 +774,11 @@ def __init__(self, data,
771774
if knn is None and bandwidth is None:
772775
raise ValueError(
773776
"Either `knn` or `bandwidth` must be provided.")
774-
if knn is not None and knn > data.shape[0]:
775-
warnings.warn("Cannot set knn ({k}) to be greater than or equal to"
776-
" n_samples ({n}). Setting knn={n}".format(
777-
k=knn, n=data.shape[0] - 1))
778-
knn = data.shape[0] - 1
777+
if knn is not None and knn > data.shape[0] - 2:
778+
warnings.warn("Cannot set knn ({k}) to be greater than "
779+
" n_samples - 2 ({n}). Setting knn={n}".format(
780+
k=knn, n=data.shape[0] - 2))
781+
knn = data.shape[0] - 2
779782
if precomputed is not None:
780783
if precomputed not in ["distance", "affinity", "adjacency"]:
781784
raise ValueError("Precomputed value {} not recognized. "
@@ -918,7 +921,8 @@ def build_kernel(self):
918921
"Choose from ['affinity', 'adjacency', 'distance', "
919922
"None]".format(self.precomputed))
920923
if self.bandwidth is None:
921-
knn_dist = np.partition(pdx, self.knn, axis=1)[:, :self.knn]
924+
knn_dist = np.partition(
925+
pdx, self.knn + 1, axis=1)[:, :self.knn + 1]
922926
bandwidth = np.max(knn_dist, axis=1)
923927
elif callable(self.bandwidth):
924928
bandwidth = self.bandwidth(pdx)
@@ -1300,8 +1304,6 @@ def build_kernel_to_data(self, Y, theta=None):
13001304
transformation of the landmarks can be trivially applied to `Y` by
13011305
performing
13021306
1303-
TODO: test this.
1304-
13051307
`transform_Y = transitions.dot(transform)`
13061308
13071309
Parameters
@@ -1323,52 +1325,6 @@ def build_kernel_to_data(self, Y, theta=None):
13231325
Transition matrix from `Y` to `self.data`
13241326
"""
13251327
raise NotImplementedError
1326-
tasklogger.log_warning("building MNN kernel to theta is experimental")
1327-
if not isinstance(self.theta, str) and \
1328-
not isinstance(self.theta, numbers.Number):
1329-
if theta is None:
1330-
raise ValueError(
1331-
"self.theta is a matrix but theta is not provided.")
1332-
elif len(theta) != len(self.samples):
1333-
raise ValueError(
1334-
"theta should have one value for every sample")
1335-
1336-
Y = self._check_extension_shape(Y)
1337-
kernel_xy = []
1338-
kernel_yx = []
1339-
# don't really need within Y kernel
1340-
Y_graph = kNNGraph(Y, n_pca=None, knn=0, **(self.knn_args))
1341-
y_knn = self._weight_knn(sample_size=Y.shape[0])
1342-
for i, X in enumerate(self.subgraphs):
1343-
kernel_xy.append(X.build_kernel_to_data(
1344-
Y, knn=self.weighted_knn[i])) # kernel X -> Y
1345-
kernel_yx.append(Y_graph.build_kernel_to_data(
1346-
X.data_nu, knn=y_knn)) # kernel Y -> X
1347-
kernel_xy = sparse.hstack(kernel_xy) # n_cells_y x n_cells_x
1348-
kernel_yx = sparse.vstack(kernel_yx) # n_cells_x x n_cells_y
1349-
1350-
# symmetrize
1351-
if theta is not None:
1352-
# Gamma can be a vector with specific values transitions for
1353-
# each batch. This allows for technical replicates and
1354-
# experimental samples to be corrected simultaneously
1355-
K = np.empty_like(kernel_xy)
1356-
for i, sample in enumerate(self.samples):
1357-
sample_idx = self.sample_idx == sample
1358-
K[:, sample_idx] = theta[i] * \
1359-
kernel_xy[:, sample_idx].minimum(
1360-
kernel_yx[sample_idx, :].T) + \
1361-
(1 - theta[i]) * \
1362-
kernel_xy[:, sample_idx].maximum(
1363-
kernel_yx[sample_idx, :].T)
1364-
if self.theta == "+":
1365-
K = (kernel_xy + kernel_yx.T) / 2
1366-
elif self.theta == "*":
1367-
K = kernel_xy.multiply(kernel_yx.T)
1368-
else:
1369-
K = self.theta * kernel_xy.minimum(kernel_yx.T) + \
1370-
(1 - self.theta) * kernel_xy.maximum(kernel_yx.T)
1371-
return K
13721328

13731329

13741330
class kNNLandmarkGraph(kNNGraph, LandmarkGraph):

test/test_exact.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_k_too_large():
103103
build_graph(data,
104104
n_pca=20,
105105
decay=10,
106-
knn=len(data) + 1,
106+
knn=len(data) - 1,
107107
thresh=0)
108108

109109

@@ -131,7 +131,7 @@ def test_exact_graph():
131131
np.fill_diagonal(W, 0)
132132
G = pygsp.graphs.Graph(W)
133133
G2 = build_graph(data_small, thresh=0, n_pca=n_pca,
134-
decay=a, knn=k, random_state=42,
134+
decay=a, knn=k - 1, random_state=42,
135135
bandwidth_scale=bandwidth_scale,
136136
use_pygsp=True)
137137
assert(G.N == G2.N)
@@ -141,7 +141,7 @@ def test_exact_graph():
141141
assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
142142
G2 = build_graph(pdx, n_pca=None, precomputed='distance',
143143
bandwidth_scale=bandwidth_scale,
144-
decay=a, knn=k, random_state=42, use_pygsp=True)
144+
decay=a, knn=k - 1, random_state=42, use_pygsp=True)
145145
assert(G.N == G2.N)
146146
np.testing.assert_equal(G.dw, G2.dw)
147147
assert((G.W != G2.W).nnz == 0)
@@ -195,7 +195,7 @@ def test_truncated_exact_graph():
195195
G2 = build_graph(data_small, thresh=thresh,
196196
graphtype='exact',
197197
n_pca=n_pca,
198-
decay=a, knn=k, random_state=42,
198+
decay=a, knn=k - 1, random_state=42,
199199
use_pygsp=True)
200200
assert(G.N == G2.N)
201201
np.testing.assert_equal(G.dw, G2.dw)
@@ -204,7 +204,7 @@ def test_truncated_exact_graph():
204204
assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
205205
G2 = build_graph(pdx, n_pca=None, precomputed='distance',
206206
thresh=thresh,
207-
decay=a, knn=k, random_state=42, use_pygsp=True)
207+
decay=a, knn=k - 1, random_state=42, use_pygsp=True)
208208
assert(G.N == G2.N)
209209
np.testing.assert_equal(G.dw, G2.dw)
210210
assert((G.W != G2.W).nnz == 0)
@@ -252,14 +252,14 @@ def test_truncated_exact_graph_sparse():
252252
G2 = build_graph(sp.coo_matrix(data_small), thresh=thresh,
253253
graphtype='exact',
254254
n_pca=n_pca,
255-
decay=a, knn=k, random_state=42,
255+
decay=a, knn=k - 1, random_state=42,
256256
use_pygsp=True)
257257
assert(G.N == G2.N)
258258
np.testing.assert_allclose(G2.W.toarray(), G.W.toarray())
259259
assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
260260
G2 = build_graph(sp.bsr_matrix(pdx), n_pca=None, precomputed='distance',
261261
thresh=thresh,
262-
decay=a, knn=k, random_state=42, use_pygsp=True)
262+
decay=a, knn=k - 1, random_state=42, use_pygsp=True)
263263
assert(G.N == G2.N)
264264
np.testing.assert_equal(G.dw, G2.dw)
265265
assert((G.W != G2.W).nnz == 0)
@@ -304,7 +304,7 @@ def test_truncated_exact_graph_no_pca():
304304
G2 = build_graph(data_small, thresh=thresh,
305305
graphtype='exact',
306306
n_pca=n_pca,
307-
decay=a, knn=k, random_state=42,
307+
decay=a, knn=k - 1, random_state=42,
308308
use_pygsp=True)
309309
assert(G.N == G2.N)
310310
np.testing.assert_equal(G.dw, G2.dw)
@@ -314,7 +314,7 @@ def test_truncated_exact_graph_no_pca():
314314
G2 = build_graph(sp.csr_matrix(data_small), thresh=thresh,
315315
graphtype='exact',
316316
n_pca=n_pca,
317-
decay=a, knn=k, random_state=42,
317+
decay=a, knn=k - 1, random_state=42,
318318
use_pygsp=True)
319319
assert(G.N == G2.N)
320320
np.testing.assert_equal(G.dw, G2.dw)
@@ -379,7 +379,7 @@ def test_exact_graph_callable_bandwidth():
379379
W = np.divide(K, 2)
380380
np.fill_diagonal(W, 0)
381381
G = pygsp.graphs.Graph(W)
382-
G2 = build_graph(data, n_pca=n_pca, knn=knn,
382+
G2 = build_graph(data, n_pca=n_pca, knn=knn - 1,
383383
decay=decay, bandwidth=bandwidth,
384384
random_state=42,
385385
thresh=thresh,
@@ -396,7 +396,7 @@ def test_exact_graph_callable_bandwidth():
396396
W = np.divide(K, 2)
397397
np.fill_diagonal(W, 0)
398398
G = pygsp.graphs.Graph(W)
399-
G2 = build_graph(data, n_pca=n_pca, knn=knn,
399+
G2 = build_graph(data, n_pca=n_pca, knn=knn - 1,
400400
decay=decay, bandwidth=bandwidth,
401401
random_state=42,
402402
thresh=thresh,
@@ -432,7 +432,7 @@ def test_exact_graph_anisotropy():
432432
np.fill_diagonal(W, 0)
433433
G = pygsp.graphs.Graph(W)
434434
G2 = build_graph(data_small, thresh=0, n_pca=n_pca,
435-
decay=a, knn=k, random_state=42,
435+
decay=a, knn=k - 1, random_state=42,
436436
use_pygsp=True, anisotropy=anisotropy)
437437
assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
438438
assert(G.N == G2.N)
@@ -441,15 +441,15 @@ def test_exact_graph_anisotropy():
441441
assert((G.W != G2.W).nnz == 0)
442442
assert_raises(ValueError, build_graph,
443443
data_small, thresh=0, n_pca=n_pca,
444-
decay=a, knn=k, random_state=42,
444+
decay=a, knn=k - 1, random_state=42,
445445
use_pygsp=True, anisotropy=-1)
446446
assert_raises(ValueError, build_graph,
447447
data_small, thresh=0, n_pca=n_pca,
448-
decay=a, knn=k, random_state=42,
448+
decay=a, knn=k - 1, random_state=42,
449449
use_pygsp=True, anisotropy=2)
450450
assert_raises(ValueError, build_graph,
451451
data_small, thresh=0, n_pca=n_pca,
452-
decay=a, knn=k, random_state=42,
452+
decay=a, knn=k - 1, random_state=42,
453453
use_pygsp=True, anisotropy='invalid')
454454

455455
#####################################################
@@ -462,32 +462,32 @@ def test_build_dense_exact_kernel_to_data(**kwargs):
462462
n = G.data.shape[0]
463463
K = G.build_kernel_to_data(data[:n // 2, :])
464464
assert(K.shape == (n // 2, n))
465-
K = G.build_kernel_to_data(G.data)
466-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
467-
K = G.build_kernel_to_data(G.data_nu)
468-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
465+
K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
466+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
467+
K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
468+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
469469

470470

471471
def test_build_dense_exact_callable_bw_kernel_to_data(**kwargs):
472472
G = build_graph(data, decay=10, thresh=0, bandwidth=lambda x: x.mean(1))
473473
n = G.data.shape[0]
474474
K = G.build_kernel_to_data(data[:n // 2, :])
475475
assert(K.shape == (n // 2, n))
476-
K = G.build_kernel_to_data(G.data)
477-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
478-
K = G.build_kernel_to_data(G.data_nu)
479-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
476+
K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
477+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
478+
K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
479+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
480480

481481

482482
def test_build_sparse_exact_kernel_to_data(**kwargs):
483483
G = build_graph(data, decay=10, thresh=0, sparse=True)
484484
n = G.data.shape[0]
485485
K = G.build_kernel_to_data(data[:n // 2, :])
486486
assert(K.shape == (n // 2, n))
487-
K = G.build_kernel_to_data(G.data)
488-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
489-
K = G.build_kernel_to_data(G.data_nu)
490-
assert(np.sum(G.kernel != (K + K.T) / 2) == 0)
487+
K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
488+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
489+
K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
490+
np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
491491

492492

493493
def test_exact_interpolate():

0 commit comments

Comments
 (0)