|
1 | 1 | import scipy |
2 | 2 | import networkx as nx |
3 | 3 | import numpy as np |
4 | | -from csnanalysis.matrix import eig_weights, mult_weights, committor |
5 | | - |
6 | | -def count_to_trans(countmat): |
7 | | - """ |
8 | | - Converts a count matrix (in scipy sparse format) to a transition |
9 | | - matrix. |
10 | | - """ |
11 | | - tmp = np.array(countmat.toarray(),dtype=float) |
12 | | - colsums = tmp.sum(axis=0) |
13 | | - for i,c in enumerate(colsums): |
14 | | - if c > 0: |
15 | | - tmp[:,i] /= c |
16 | | - |
17 | | - return(scipy.sparse.coo_matrix(tmp)) |
18 | | - |
19 | | -def symmetrize(countmat): |
20 | | - """ |
21 | | - Symmetrizes a count matrix (in scipy sparse format). |
22 | | - """ |
23 | | - return 0.5*(countmat + countmat.transpose()) |
| 4 | +from csnanalysis.matrix import * |
24 | 5 |
|
25 | 6 | class CSN(object): |
26 | 7 |
|
@@ -55,15 +36,51 @@ def __init__(self, counts, symmetrize=False): |
55 | 36 |
|
56 | 37 | # initialize networkX directed graph |
57 | 38 | self.graph = nx.DiGraph() |
58 | | - labels = [{'ID' : i} for i in range(self.nnodes)] |
| 39 | + labels = [{'label' : i, 'ID' : i} for i in range(self.nnodes)] |
59 | 40 | self.graph.add_nodes_from(zip(range(self.nnodes),labels)) |
60 | 41 | self.graph.add_weighted_edges_from(zip(self.transmat.col,self.transmat.row,self.transmat.data)) |
61 | 42 |
|
62 | | - def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv'): |
| 43 | + def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv', directed=False): |
63 | 44 | """ |
64 | 45 | Writes node and edge files for import into the Gephi network visualization program. |
| 46 | +
|
| 47 | + cols -- A list of columns that should be written to the node file. ID and label are |
| 48 | + included by default. 'all' will include every attribute attached to the |
| 49 | + nodes in self.graph. |
| 50 | +
|
65 | 51 | """ |
| 52 | + if cols == 'all': |
| 53 | + cols = list(self.graph.node[0].keys()) |
| 54 | + else: |
| 55 | + if 'label' not in cols: |
| 56 | + cols = ['label'] + cols |
| 57 | + if 'ID' not in cols: |
| 58 | + cols = ['ID'] + cols |
66 | 59 |
|
| 60 | + with open(node_name,mode='w') as f: |
| 61 | + f.write(" ".join(cols)+"\n") |
| 62 | + for i in range(self.nnodes): |
| 63 | + data = [str(self.graph.node[i][c]) for c in cols] |
| 64 | + f.write(' '.join(data)+"\n") |
| 65 | + |
| 66 | + # compute edge weights |
| 67 | + if directed: |
| 68 | + with open(edge_name,mode='w') as f: |
| 69 | + f.write("source target type prob i_weight\n") |
| 70 | + for from_ind,edge_dict in self.graph.edge.items(): |
| 71 | + for to_ind,edge in edge_dict.items(): |
| 72 | + f.write("{0:d} {1:d} {2:s} {3:f} {4:d}\n".format(from_ind,to_ind,'Directed',edge['weight'],int(edge['weight']*100))) |
| 73 | + else: |
| 74 | + with open(edge_name,mode='w') as f: |
| 75 | + f.write("source target type prob i_weight\n") |
| 76 | + for from_ind,edge_dict in self.graph.edge.items(): |
| 77 | + for to_ind,edge in edge_dict.items(): |
| 78 | + if from_ind <= to_ind: |
| 79 | + if to_ind in self.graph.edge and from_ind in self.graph.edge[to_ind]: |
| 80 | + edge_weight = 0.5*(self.graph.edge[to_ind][from_ind]['weight'] + edge['weight']) |
| 81 | + else: |
| 82 | + edge_weight = 0.5*edge['weight'] |
| 83 | + f.write("{0:d} {1:d} {2:s} {3:f} {4:d}\n".format(from_ind,to_ind,'Undirected',edge_weight,int(edge_weight*100))) |
67 | 84 |
|
68 | 85 | def add_attr(self, name, values): |
69 | 86 | """ |
@@ -111,7 +128,7 @@ def trim(self, by_inflow=True, by_outflow=True, min_count=0): |
111 | 128 | if min_count > 0: |
112 | 129 | mask[[i for i in range(self.nnodes) if totcounts[i] < min_count]] = False |
113 | 130 |
|
114 | | - self.trim_indices = [i for i in range(self.nnodes) if mask[i] is True] |
| 131 | + self.trim_indices = [i for i in range(self.nnodes) if mask[i] == True] |
115 | 132 | self.trim_graph = self.graph.subgraph(self.trim_indices) |
116 | 133 |
|
117 | 134 | tmp_arr = self.countmat.toarray()[mask,...][...,mask] |
@@ -140,7 +157,7 @@ def calc_eig_weights(self,label='eig_weights'): |
140 | 157 | else: |
141 | 158 | # use trimmed transition matrix |
142 | 159 | wts = eig_weights(self.trim_transmat) |
143 | | - full_wts = np.zeros(self.nnodes,dtype=float64) |
| 160 | + full_wts = np.zeros(self.nnodes,dtype=float) |
144 | 161 | for i,ind in enumerate(self.trim_indices): |
145 | 162 | full_wts[ind] = wts[i] |
146 | 163 | self.add_attr(label, full_wts) |
|
0 commit comments