Skip to content

Commit 30b0cd8

Browse files
committed
callable bandwidth should operate on the rows of the distance matrix
1 parent 4e0b970 commit 30b0cd8

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

graphtools/api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,11 @@ def Graph(data,
6565
decay : `int` or `None`, optional (default: 10)
6666
Rate of alpha decay to use. If `None`, alpha decay is not used.
6767
68-
bandwidth : `float`, list-like or `None`, optional (default: `None`)
68+
bandwidth : `float`, list-like,`callable`, or `None`, optional (default: `None`)
6969
Fixed bandwidth to use. If given, overrides `knn`. Can be a single
70-
bandwidth or a list-like (shape=[n_samples]) of bandwidths for each
71-
sample.
70+
bandwidth, list-like (shape=[n_samples]) of bandwidths for each
71+
sample, or a `callable` that takes in a `n x m` matrix and returns a
72+
a single value or list-like of length n (shape=[n_samples])
7273
7374
bandwidth_scale : `float`, optional (default : 1.0)
7475
Rescaling factor for bandwidth.

graphtools/graphs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,8 @@ class TraditionalGraph(DataGraph):
721721
bandwidth : `float`, list-like,`callable`, or `None`, optional (default: `None`)
722722
Fixed bandwidth to use. If given, overrides `knn`. Can be a single
723723
bandwidth, list-like (shape=[n_samples]) of bandwidths for each
724-
sample, or a `callable` that takes in a square matrix and returns a
725-
a single value or list-like(shape=[n_samples])
724+
sample, or a `callable` that takes in a `n x m` matrix and returns a
725+
a single value or list-like of length n (shape=[n_samples])
726726
727727
bandwidth_scale : `float`, optional (default : 1.0)
728728
Rescaling factor for bandwidth.

test/test_exact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def test_build_dense_exact_kernel_to_data(**kwargs):
469469

470470

471471
def test_build_dense_exact_callable_bw_kernel_to_data(**kwargs):
472-
G = build_graph(data, decay=10, thresh=0, bandwidth=lambda x: x.mean(0))
472+
G = build_graph(data, decay=10, thresh=0, bandwidth=lambda x: x.mean(1))
473473
n = G.data.shape[0]
474474
K = G.build_kernel_to_data(data[:n // 2, :])
475475
assert(K.shape == (n // 2, n))

0 commit comments

Comments
 (0)