Skip to content

Commit b57428e

Browse files
committed
to_gephi
1 parent 0220639 commit b57428e

File tree

2 files changed

+60
-24
lines changed

2 files changed

+60
-24
lines changed

csnanalysis/csn.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,7 @@
11
import scipy
22
import networkx as nx
33
import numpy as np
4-
from csnanalysis.matrix import eig_weights, mult_weights, committor
5-
6-
def count_to_trans(countmat):
7-
"""
8-
Converts a count matrix (in scipy sparse format) to a transition
9-
matrix.
10-
"""
11-
tmp = np.array(countmat.toarray(),dtype=float)
12-
colsums = tmp.sum(axis=0)
13-
for i,c in enumerate(colsums):
14-
if c > 0:
15-
tmp[:,i] /= c
16-
17-
return(scipy.sparse.coo_matrix(tmp))
18-
19-
def symmetrize(countmat):
20-
"""
21-
Symmetrizes a count matrix (in scipy sparse format).
22-
"""
23-
return 0.5*(countmat + countmat.transpose())
4+
from csnanalysis.matrix import *
245

256
class CSN(object):
267

@@ -55,15 +36,51 @@ def __init__(self, counts, symmetrize=False):
5536

5637
# initialize networkX directed graph
5738
self.graph = nx.DiGraph()
58-
labels = [{'ID' : i} for i in range(self.nnodes)]
39+
labels = [{'label' : i, 'ID' : i} for i in range(self.nnodes)]
5940
self.graph.add_nodes_from(zip(range(self.nnodes),labels))
6041
self.graph.add_weighted_edges_from(zip(self.transmat.col,self.transmat.row,self.transmat.data))
6142

62-
def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv'):
43+
def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv', directed=False):
6344
"""
6445
Writes node and edge files for import into the Gephi network visualization program.
46+
47+
cols -- A list of columns that should be written to the node file. ID and label are
48+
included by default. 'all' will include every attribute attached to the
49+
nodes in self.graph.
50+
6551
"""
52+
if cols == 'all':
53+
cols = list(self.graph.node[0].keys())
54+
else:
55+
if 'label' not in cols:
56+
cols = ['label'] + cols
57+
if 'ID' not in cols:
58+
cols = ['ID'] + cols
6659

60+
with open(node_name,mode='w') as f:
61+
f.write(" ".join(cols)+"\n")
62+
for i in range(self.nnodes):
63+
data = [str(self.graph.node[i][c]) for c in cols]
64+
f.write(' '.join(data)+"\n")
65+
66+
# compute edge weights
67+
if directed:
68+
with open(edge_name,mode='w') as f:
69+
f.write("source target type prob i_weight\n")
70+
for from_ind,edge_dict in self.graph.edge.items():
71+
for to_ind,edge in edge_dict.items():
72+
f.write("{0:d} {1:d} {2:s} {3:f} {4:d}\n".format(from_ind,to_ind,'Directed',edge['weight'],int(edge['weight']*100)))
73+
else:
74+
with open(edge_name,mode='w') as f:
75+
f.write("source target type prob i_weight\n")
76+
for from_ind,edge_dict in self.graph.edge.items():
77+
for to_ind,edge in edge_dict.items():
78+
if from_ind <= to_ind:
79+
if to_ind in self.graph.edge and from_ind in self.graph.edge[to_ind]:
80+
edge_weight = 0.5*(self.graph.edge[to_ind][from_ind]['weight'] + edge['weight'])
81+
else:
82+
edge_weight = 0.5*edge['weight']
83+
f.write("{0:d} {1:d} {2:s} {3:f} {4:d}\n".format(from_ind,to_ind,'Undirected',edge_weight,int(edge_weight*100)))
6784

6885
def add_attr(self, name, values):
6986
"""
@@ -111,7 +128,7 @@ def trim(self, by_inflow=True, by_outflow=True, min_count=0):
111128
if min_count > 0:
112129
mask[[i for i in range(self.nnodes) if totcounts[i] < min_count]] = False
113130

114-
self.trim_indices = [i for i in range(self.nnodes) if mask[i] is True]
131+
self.trim_indices = [i for i in range(self.nnodes) if mask[i] == True]
115132
self.trim_graph = self.graph.subgraph(self.trim_indices)
116133

117134
tmp_arr = self.countmat.toarray()[mask,...][...,mask]
@@ -140,7 +157,7 @@ def calc_eig_weights(self,label='eig_weights'):
140157
else:
141158
# use trimmed transition matrix
142159
wts = eig_weights(self.trim_transmat)
143-
full_wts = np.zeros(self.nnodes,dtype=float64)
160+
full_wts = np.zeros(self.nnodes,dtype=float)
144161
for i,ind in enumerate(self.trim_indices):
145162
full_wts[ind] = wts[i]
146163
self.add_attr(label, full_wts)

csnanalysis/matrix.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,25 @@
22
import scipy
33
import numpy as np
44

5+
def count_to_trans(countmat):
6+
"""
7+
Converts a count matrix (in scipy sparse format) to a transition
8+
matrix.
9+
"""
10+
tmp = np.array(countmat.toarray(),dtype=float)
11+
colsums = tmp.sum(axis=0)
12+
for i,c in enumerate(colsums):
13+
if c > 0:
14+
tmp[:,i] /= c
15+
16+
return(scipy.sparse.coo_matrix(tmp))
17+
18+
def symmetrize(countmat):
19+
"""
20+
Symmetrizes a count matrix (in scipy sparse format).
21+
"""
22+
return 0.5*(countmat + countmat.transpose())
23+
524
def make_sink(transmat,sink_states):
625
"""
726
Constructs a transition matrix with "sink states", where the columns are

0 commit comments

Comments
 (0)