Skip to content

Commit 40fefb1

Browse files
authored
Merge pull request #35 from KrishnaswamyLab/feature/to_pkl
adding to_pickle and read_pickle
2 parents 31ee73e + 6abacc7 commit 40fefb1

File tree

5 files changed

+93
-3
lines changed

5 files changed

+93
-3
lines changed

graphtools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .api import Graph, from_igraph
1+
from .api import Graph, from_igraph, read_pickle
22
from .version import __version__

graphtools/api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import warnings
33
import tasklogger
44
from scipy import sparse
5+
import pickle
6+
import pygsp
57

68
from . import base
79
from . import graphs
@@ -283,3 +285,22 @@ def from_igraph(G, attribute="weight", **kwargs):
283285
K = G.get_adjacency(attribute=None).data
284286
return Graph(sparse.coo_matrix(K),
285287
precomputed='adjacency', **kwargs)
288+
289+
290+
def read_pickle(path):
291+
"""Load pickled Graphtools object (or any object) from file.
292+
293+
Parameters
294+
----------
295+
path : str
296+
File path where the pickled object will be loaded.
297+
"""
298+
with open(path, 'rb') as f:
299+
G = pickle.load(f)
300+
301+
if not isinstance(G, base.BaseGraph):
302+
warnings.warn(
303+
'Returning object that is not a graphtools.base.BaseGraph')
304+
elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
305+
G.logger = pygsp.utils.build_logger(G.logger)
306+
return G

graphtools/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import warnings
1111
import numbers
1212
import tasklogger
13+
import pickle
14+
import sys
1315

1416
try:
1517
import pandas as pd
@@ -635,6 +637,23 @@ def to_igraph(self, attribute="weight", **kwargs):
635637
return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(),
636638
attr=attribute, **kwargs)
637639

640+
def to_pickle(self, path):
641+
"""Save the current Graph to a pickle.
642+
643+
Parameters
644+
----------
645+
path : str
646+
File path where the pickled object will be stored.
647+
"""
648+
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
649+
# python 3.5, 3.6
650+
logger = self.logger
651+
self.logger = logger.name
652+
with open(path, 'wb') as f:
653+
pickle.dump(self, f)
654+
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
655+
self.logger = logger
656+
638657

639658
class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
640659
"""Interface between BaseGraph and PyGSP.

graphtools/graphs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import tasklogger
1515

1616
from .utils import (set_diagonal,
17-
elementwise_minimum,
18-
elementwise_maximum,
1917
set_submatrix)
2018
from .base import DataGraph, PyGSPGraph
2119

test/test_api.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import igraph
1010
import numpy as np
1111
import graphtools
12+
import tempfile
13+
import os
1214

1315

1416
def test_from_igraph():
@@ -81,6 +83,56 @@ def test_to_igraph():
8183
attribute="weight").data) == G.W)
8284

8385

86+
def test_pickle_io_knngraph():
87+
G = build_graph(data, knn=5, decay=None)
88+
with tempfile.TemporaryDirectory() as tempdir:
89+
path = os.path.join(tempdir, 'tmp.pkl')
90+
G.to_pickle(path)
91+
G_prime = graphtools.read_pickle(path)
92+
assert isinstance(G_prime, type(G))
93+
94+
95+
def test_pickle_io_traditionalgraph():
96+
G = build_graph(data, knn=5, decay=10, thresh=0)
97+
with tempfile.TemporaryDirectory() as tempdir:
98+
path = os.path.join(tempdir, 'tmp.pkl')
99+
G.to_pickle(path)
100+
G_prime = graphtools.read_pickle(path)
101+
assert isinstance(G_prime, type(G))
102+
103+
104+
def test_pickle_io_landmarkgraph():
105+
G = build_graph(data, knn=5, decay=None,
106+
n_landmark=data.shape[0] // 2)
107+
L = G.landmark_op
108+
with tempfile.TemporaryDirectory() as tempdir:
109+
path = os.path.join(tempdir, 'tmp.pkl')
110+
G.to_pickle(path)
111+
G_prime = graphtools.read_pickle(path)
112+
assert isinstance(G_prime, type(G))
113+
np.testing.assert_array_equal(L, G_prime._landmark_op)
114+
115+
116+
def test_pickle_io_pygspgraph():
117+
G = build_graph(data, knn=5, decay=None, use_pygsp=True)
118+
with tempfile.TemporaryDirectory() as tempdir:
119+
path = os.path.join(tempdir, 'tmp.pkl')
120+
G.to_pickle(path)
121+
G_prime = graphtools.read_pickle(path)
122+
assert isinstance(G_prime, type(G))
123+
assert G_prime.logger.name == G.logger.name
124+
125+
126+
@warns(UserWarning)
127+
def test_pickle_bad_pickle():
128+
import pickle
129+
with tempfile.TemporaryDirectory() as tempdir:
130+
path = os.path.join(tempdir, 'tmp.pkl')
131+
with open(path, 'wb') as f:
132+
pickle.dump('hello world', f)
133+
G = graphtools.read_pickle(path)
134+
135+
84136
@warns(UserWarning)
85137
def test_to_pygsp_invalid_precomputed():
86138
G = build_graph(data)

0 commit comments

Comments
 (0)