Skip to content

Commit 352e8ff

Browse files
authored
Merge pull request #31 from KrishnaswamyLab/feature/callable_bw
Feature/callable bw
2 parents de4a123 + 30b0cd8 commit 352e8ff

File tree

6 files changed

+320
-63
lines changed

6 files changed

+320
-63
lines changed

graphtools/api.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def Graph(data,
1515
knn=5,
1616
decay=10,
1717
bandwidth=None,
18+
bandwidth_scale=1.0,
1819
anisotropy=0,
1920
distance='euclidean',
2021
thresh=1e-4,
@@ -64,10 +65,14 @@ def Graph(data,
6465
decay : `int` or `None`, optional (default: 10)
6566
Rate of alpha decay to use. If `None`, alpha decay is not used.
6667
67-
bandwidth : `float`, list-like or `None`, optional (default: `None`)
68+
bandwidth : `float`, list-like,`callable`, or `None`, optional (default: `None`)
6869
Fixed bandwidth to use. If given, overrides `knn`. Can be a single
69-
bandwidth or a list-like (shape=[n_samples]) of bandwidths for each
70-
sample.
70+
bandwidth, list-like (shape=[n_samples]) of bandwidths for each
71+
sample, or a `callable` that takes in a `n x m` matrix and returns a
72+
a single value or list-like of length n (shape=[n_samples])
73+
74+
bandwidth_scale : `float`, optional (default : 1.0)
75+
Rescaling factor for bandwidth.
7176
7277
anisotropy : float, optional (default: 0)
7378
Level of anisotropy between 0 and 1
@@ -161,12 +166,18 @@ def Graph(data,
161166
if sample_idx is not None:
162167
# only mnn does batch correction
163168
graphtype = "mnn"
164-
elif precomputed is None and (decay is None or thresh > 0):
169+
elif precomputed is not None:
165170
# precomputed requires exact graph
166-
# no decay or threshold decay require knngraph
171+
graphtype = "exact"
172+
elif decay is None:
173+
# knn kernel
167174
graphtype = "knn"
168-
else:
175+
elif thresh == 0 or callable(bandwidth):
176+
# compute full distance matrix
169177
graphtype = "exact"
178+
else:
179+
# decay kernel with nonzero threshold - knn is more efficient
180+
graphtype = "knn"
170181

171182
# set base graph type
172183
if graphtype == "knn":

graphtools/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.utils.fixes import signature
77
from sklearn.decomposition import PCA, TruncatedSVD
88
from sklearn.preprocessing import normalize
9+
from sklearn.utils.graph import graph_shortest_path
910
from scipy import sparse
1011
import warnings
1112
import numbers
@@ -643,13 +644,13 @@ class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
643644
kernel matrix
644645
"""
645646

646-
def __init__(self, gtype='unknown', lap_type='combinatorial', coords=None,
647+
def __init__(self, lap_type='combinatorial', coords=None,
647648
plotting=None, **kwargs):
648649
if plotting is None:
649650
plotting = {}
650651
W = self._build_weight_from_kernel(self.K)
651652

652-
super().__init__(W=W, gtype=gtype,
653+
super().__init__(W=W,
653654
lap_type=lap_type,
654655
coords=coords,
655656
plotting=plotting, **kwargs)

0 commit comments

Comments
 (0)