Skip to content

Commit 3edc962

Browse files
authored
Merge pull request #36 from KrishnaswamyLab/dev
graphtools v1.1
2 parents 29b6f90 + 40fefb1 commit 3edc962

File tree

7 files changed

+107
-12
lines changed

7 files changed

+107
-12
lines changed

graphtools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .api import Graph, from_igraph
1+
from .api import Graph, from_igraph, read_pickle
22
from .version import __version__

graphtools/api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import warnings
33
import tasklogger
44
from scipy import sparse
5+
import pickle
6+
import pygsp
57

68
from . import base
79
from . import graphs
@@ -283,3 +285,22 @@ def from_igraph(G, attribute="weight", **kwargs):
283285
K = G.get_adjacency(attribute=None).data
284286
return Graph(sparse.coo_matrix(K),
285287
precomputed='adjacency', **kwargs)
288+
289+
290+
def read_pickle(path):
    """Load a pickled graphtools object (or any other object) from file.

    Parameters
    ----------
    path : str
        File path from which the pickled object will be loaded.

    Returns
    -------
    G : object
        The unpickled object. It is returned even when it is not a
        `graphtools.base.BaseGraph`; in that case a warning is issued.

    Warns
    -----
    UserWarning
        If the unpickled object is not a `graphtools.base.BaseGraph`.
    """
    # SECURITY: unpickling can execute arbitrary code. Only call this on
    # files that come from a trusted source.
    with open(path, 'rb') as f:
        G = pickle.load(f)

    if not isinstance(G, base.BaseGraph):
        # stacklevel=2 attributes the warning to the caller of read_pickle
        # rather than to this module.
        warnings.warn(
            'Returning object that is not a graphtools.base.BaseGraph',
            stacklevel=2)
    elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
        # The logger was stored as a plain string; rebuild a logger object
        # from it.
        G.logger = pygsp.utils.build_logger(G.logger)
    return G

graphtools/base.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import numpy as np
44
import abc
55
import pygsp
6-
from sklearn.utils.fixes import signature
6+
from inspect import signature
77
from sklearn.decomposition import PCA, TruncatedSVD
88
from sklearn.preprocessing import normalize
9-
from sklearn.utils.graph import graph_shortest_path
109
from scipy import sparse
1110
import warnings
1211
import numbers
1312
import tasklogger
13+
import pickle
14+
import sys
1415

1516
try:
1617
import pandas as pd
@@ -106,10 +107,10 @@ class Data(Base):
106107
def __init__(self, data, n_pca=None, random_state=None, **kwargs):
107108

108109
self._check_data(data)
109-
if n_pca is not None and data.shape[1] <= n_pca:
110+
if n_pca is not None and np.min(data.shape) <= n_pca:
110111
warnings.warn("Cannot perform PCA to {} dimensions on "
111-
"data with {} dimensions".format(n_pca,
112-
data.shape[1]),
112+
"data with min(n_samples, n_features) = {}".format(
113+
n_pca, np.min(data.shape)),
113114
RuntimeWarning)
114115
n_pca = None
115116
try:
@@ -316,7 +317,7 @@ class BaseGraph(with_metaclass(abc.ABCMeta, Base)):
316317
'theta' : min-max
317318
'none' : no symmetrization
318319
319-
theta: float (default: 0.5)
320+
theta: float (default: 1)
320321
Min-max symmetrization constant.
321322
K = `theta * min(K, K.T) + (1 - theta) * max(K, K.T)`
322323
@@ -385,7 +386,7 @@ def _check_symmetrization(self, kernel_symm, theta):
385386
if theta is None:
386387
warnings.warn("kernel_symm='theta' but theta not given. "
387388
"Defaulting to theta=0.5.")
388-
self.theta = theta = 0.5
389+
self.theta = theta = 1
389390
elif not isinstance(theta, numbers.Number) or \
390391
theta < 0 or theta > 1:
391392
raise ValueError("theta {} not recognized. Expected "
@@ -636,6 +637,23 @@ def to_igraph(self, attribute="weight", **kwargs):
636637
return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(),
637638
attr=attribute, **kwargs)
638639

640+
def to_pickle(self, path):
    """Save the current Graph to a pickle.

    Parameters
    ----------
    path : str
        File path where the pickled object will be stored.
    """
    # Compare the version tuple rather than parsing ``sys.version``: the
    # string parse inspected only the minor field and ignored the major
    # version entirely.
    swap_logger = (sys.version_info < (3, 7) and
                   isinstance(self, pygsp.graphs.Graph))
    if swap_logger:
        # python 3.5, 3.6: store the logger by name while pickling
        # (presumably because logger objects cannot be pickled on these
        # versions -- confirm against the logging documentation).
        logger = self.logger
        self.logger = logger.name
    try:
        with open(path, 'wb') as f:
            pickle.dump(self, f)
    finally:
        # Always put the real logger back, even if writing failed, so the
        # graph is never left holding a string in place of its logger.
        if swap_logger:
            self.logger = logger
656+
639657

640658
class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
641659
"""Interface between BaseGraph and PyGSP.

graphtools/graphs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import tasklogger
1515

1616
from .utils import (set_diagonal,
17-
elementwise_minimum,
18-
elementwise_maximum,
1917
set_submatrix)
2018
from .base import DataGraph, PyGSPGraph
2119

@@ -245,7 +243,7 @@ def _check_duplicates(self, distances, indices):
245243
"Detected zero distance between {} pairs of samples. "
246244
"Consider removing duplicates to avoid errors in "
247245
"downstream processing.".format(
248-
np.sum(np.sum(distances[:, 1:]))),
246+
np.sum(np.sum(distances[:, 1:] == 0))),
249247
RuntimeWarning)
250248

251249
def build_kernel_to_data(self, Y, knn=None, bandwidth=None,

graphtools/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0"
1+
__version__ = "1.1.0"

test/test_api.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import igraph
1010
import numpy as np
1111
import graphtools
12+
import tempfile
13+
import os
1214

1315

1416
def test_from_igraph():
@@ -81,6 +83,56 @@ def test_to_igraph():
8183
attribute="weight").data) == G.W)
8284

8385

86+
def test_pickle_io_knngraph():
    """Round-trip a graph built with knn=5, decay=None through pickle."""
    original = build_graph(data, knn=5, decay=None)
    with tempfile.TemporaryDirectory() as scratch:
        pickle_file = os.path.join(scratch, 'tmp.pkl')
        original.to_pickle(pickle_file)
        restored = graphtools.read_pickle(pickle_file)
    # The restored object must be of the same graph class as the original.
    assert isinstance(restored, type(original))
93+
94+
95+
def test_pickle_io_traditionalgraph():
    """Round-trip a graph built with decay=10, thresh=0 through pickle."""
    original = build_graph(data, knn=5, decay=10, thresh=0)
    with tempfile.TemporaryDirectory() as scratch:
        pickle_file = os.path.join(scratch, 'tmp.pkl')
        original.to_pickle(pickle_file)
        restored = graphtools.read_pickle(pickle_file)
    # The restored object must be of the same graph class as the original.
    assert isinstance(restored, type(original))
102+
103+
104+
def test_pickle_io_landmarkgraph():
    """Round-trip a landmark graph and compare its landmark operator."""
    half = data.shape[0] // 2
    original = build_graph(data, knn=5, decay=None, n_landmark=half)
    # Read the landmark operator before pickling, as the reference value.
    expected_op = original.landmark_op
    with tempfile.TemporaryDirectory() as scratch:
        pickle_file = os.path.join(scratch, 'tmp.pkl')
        original.to_pickle(pickle_file)
        restored = graphtools.read_pickle(pickle_file)
    assert isinstance(restored, type(original))
    np.testing.assert_array_equal(expected_op, restored._landmark_op)
114+
115+
116+
def test_pickle_io_pygspgraph():
    """Round-trip a use_pygsp=True graph and check its logger name."""
    original = build_graph(data, knn=5, decay=None, use_pygsp=True)
    with tempfile.TemporaryDirectory() as scratch:
        pickle_file = os.path.join(scratch, 'tmp.pkl')
        original.to_pickle(pickle_file)
        restored = graphtools.read_pickle(pickle_file)
    assert isinstance(restored, type(original))
    # The logger must come back as an object carrying the original name.
    assert restored.logger.name == original.logger.name
124+
125+
126+
@warns(UserWarning)
def test_pickle_bad_pickle():
    """Reading a pickle that holds a plain string must emit a UserWarning."""
    import pickle
    with tempfile.TemporaryDirectory() as scratch:
        pickle_file = os.path.join(scratch, 'tmp.pkl')
        with open(pickle_file, 'wb') as handle:
            pickle.dump('hello world', handle)
        # Only the warning matters here; the returned object is not used.
        graphtools.read_pickle(pickle_file)
134+
135+
84136
@warns(UserWarning)
85137
def test_to_pygsp_invalid_precomputed():
86138
G = build_graph(data)

test/test_data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def test_too_many_n_pca():
4444
build_graph(data, n_pca=data.shape[1])
4545

4646

47+
@warns(RuntimeWarning)
def test_too_many_n_pca_few_samples():
    """Requesting n_pca equal to the sample count must emit a RuntimeWarning.

    The data is truncated to ``n_features - 1`` samples and n_pca is set to
    that same number.

    Named distinctly from ``test_too_many_n_pca``: reusing that name would
    rebind it at module level, so the earlier test would never run.
    """
    n_samples = data.shape[1] - 1
    build_graph(data[:n_samples], n_pca=n_samples)
51+
52+
4753
@warns(RuntimeWarning)
4854
def test_precomputed_with_pca():
4955
build_graph(squareform(pdist(data)),

0 commit comments

Comments
 (0)