Skip to content

Commit 76ba4f0

Browse files
committed
make tests pass
1 parent 5104b74 commit 76ba4f0

File tree

7 files changed

+34
-13
lines changed

7 files changed

+34
-13
lines changed

graphtools/graphs.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,25 @@ class kNNGraph(DataGraph):
6464
def __init__(self, data, knn=5, decay=None,
6565
bandwidth=None, distance='euclidean',
6666
thresh=1e-4, n_pca=None, **kwargs):
67-
self.knn = knn
68-
self.decay = decay
69-
self.bandwidth = bandwidth
70-
self.distance = distance
71-
self.thresh = thresh
7267

7368
if decay is not None and thresh <= 0:
7469
raise ValueError("Cannot instantiate a kNNGraph with `decay=None` "
7570
"and `thresh=0`. Use a TraditionalGraph instead.")
7671
if knn > data.shape[0]:
7772
warnings.warn("Cannot set knn ({k}) to be greater than "
78-
"data.shape[0] ({n}). Setting knn={n}".format(
73+
"n_samples ({n}). Setting knn={n}".format(
7974
k=knn, n=data.shape[0]))
75+
knn = data.shape[0]
8076
if n_pca is None and data.shape[1] > 500:
8177
warnings.warn("Building a kNNGraph on data of shape {} is "
8278
"expensive. Consider setting n_pca.".format(
8379
data.shape), UserWarning)
8480

81+
self.knn = knn
82+
self.decay = decay
83+
self.bandwidth = bandwidth
84+
self.distance = distance
85+
self.thresh = thresh
8586
super().__init__(data, n_pca=n_pca, **kwargs)
8687

8788
def get_params(self):
@@ -232,7 +233,7 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None):
232233
bandwidth = self.bandwidth
233234
if knn > self.data.shape[0]:
234235
warnings.warn("Cannot set knn ({k}) to be greater than "
235-
"data.shape[0] ({n}). Setting knn={n}".format(
236+
"n_samples ({n}). Setting knn={n}".format(
236237
k=knn, n=self.data.shape[0]))
237238

238239
Y = self._check_extension_shape(Y)
@@ -675,15 +676,20 @@ def __init__(self, data,
675676
n_pca=None,
676677
thresh=1e-4,
677678
precomputed=None, **kwargs):
679+
if decay is None and precomputed not in ['affinity', 'adjacency']:
680+
# decay high enough is basically a binary kernel
681+
raise ValueError("`decay` must be provided for a TraditionalGraph"
682+
". For kNN kernel, use kNNGraph.")
678683
if precomputed is not None and n_pca is not None:
679684
# the data itself is a matrix of distances / affinities
680685
n_pca = None
681686
warnings.warn("n_pca cannot be given on a precomputed graph."
682687
" Setting n_pca=None", RuntimeWarning)
683-
if decay is None and precomputed not in ['affinity', 'adjacency']:
684-
# decay high enough is basically a binary kernel
685-
raise ValueError("`decay` must be provided for a TraditionalGraph"
686-
". For kNN kernel, use kNNGraph.")
688+
if knn > data.shape[0]:
689+
warnings.warn("Cannot set knn ({k}) to be greater than or equal to"
690+
" n_samples ({n}). Setting knn={n}".format(
691+
k=knn, n=data.shape[0] - 1))
692+
knn = data.shape[0] - 1
687693
if precomputed is not None:
688694
if precomputed not in ["distance", "affinity", "adjacency"]:
689695
raise ValueError("Precomputed value {} not recognized. "

test/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
nose2,
34
data,

test/test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
np,
34
sp,

test/test_exact.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
graphtools,
34
np,
@@ -83,6 +84,15 @@ def test_duplicate_data():
8384
thresh=0)
8485

8586

87+
@warns(UserWarning)
88+
def test_k_too_large():
89+
build_graph(data,
90+
n_pca=20,
91+
decay=10,
92+
knn=len(data) + 1,
93+
thresh=0)
94+
95+
8696
#####################################################
8797
# Check kernel
8898
#####################################################

test/test_knn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
graphtools,
34
np,

test/test_landmark.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
graphtools,
34
np,
@@ -43,7 +44,7 @@ def test_landmark_exact_graph():
4344
assert(isinstance(G, graphtools.graphs.TraditionalGraph))
4445
assert(isinstance(G, graphtools.graphs.LandmarkGraph))
4546
assert(G.transitions.shape == (data.shape[0], n_landmark))
46-
assert(G.clusters.shape == data.shape[0])
47+
assert(G.clusters.shape == (data.shape[0],))
4748
assert(len(np.unique(G.clusters)) <= n_landmark)
4849
signal = np.random.normal(0, 1, [n_landmark, 10])
4950
interpolated_signal = G.interpolate(signal)
@@ -72,7 +73,7 @@ def test_landmark_mnn_graph():
7273
thresh=1e-5, n_pca=None,
7374
decay=10, knn=5, random_state=42,
7475
sample_idx=sample_idx)
75-
assert(G.clusters.shape == data.shape[0])
76+
assert(G.clusters.shape == (X.shape[0],))
7677
assert(G.landmark_op.shape == (n_landmark, n_landmark))
7778
assert(isinstance(G, graphtools.graphs.MNNGraph))
7879
assert(isinstance(G, graphtools.graphs.LandmarkGraph))

test/test_mnn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import print_function
12
from load_tests import (
23
graphtools,
34
np,

0 commit comments

Comments
 (0)