11import scipy
22import networkx as nx
33import numpy as np
4+ from csnanalysis .matrix import eig_weights , mult_weights
45
56def count_to_trans (countmat ):
67 """
78 Converts a count matrix (in scipy sparse format) to a transition
89 matrix.
910 """
1011 tmp = np .array (countmat .toarray (),dtype = float )
11- colsums = tmp .sum (axis = 1 )
12+ colsums = tmp .sum (axis = 0 )
1213 for i ,c in enumerate (colsums ):
1314 if c > 0 :
14- tmp [i ] /= c
15+ tmp [:, i ] /= c
1516
1617 return (scipy .sparse .coo_matrix (tmp ))
1718
@@ -26,7 +27,7 @@ class CSN(object):
2627 def __init__ (self , counts , symmetrize = False ):
2728 """
2829 Initializes a CSN object using a counts matrix. This can either be a numpy array,
29- a scipy sparse matrix, or a list of lists.
30+ a scipy sparse matrix, or a list of lists. Indices: [to][from], (or, [row][column]).
3031 """
3132 if type (counts ) is list :
3233 self .countmat = scipy .sparse .coo_matrix (counts )
@@ -49,12 +50,14 @@ def __init__(self, counts, symmetrize=False):
4950
5051 self .nnodes = self .countmat .shape [0 ]
5152 self .transmat = count_to_trans (self .countmat )
53+
54+ self .trim_transmat = None
5255
5356 # initialize networkX directed graph
5457 self .graph = nx .DiGraph ()
5558 labels = [{'ID' : i } for i in range (self .nnodes )]
5659 self .graph .add_nodes_from (zip (range (self .nnodes ),labels ))
57- self .graph .add_weighted_edges_from (zip (self .transmat .row ,self .transmat .col ,self .transmat .data ))
60+ self .graph .add_weighted_edges_from (zip (self .transmat .col ,self .transmat .row ,self .transmat .data ))
5861
5962 def to_gephi (self , cols = 'all' , node_name = 'node.csv' , edge_name = 'edge.csv' ):
6063 """
@@ -118,3 +121,52 @@ def trim_graph(self, by_inflow=True, by_outflow=True, min_count=0):
118121 self .trim_nnodes = self .trim_countmat .shape [0 ]
119122 self .trim_transmat = count_to_trans (self .trim_countmat )
120123
124+
125+ def calc_eig_weights (self ,label = 'eig_weights' ):
126+ """
127+ Calculates weights of states using the highest Eigenvalue of the
128+ transition matrix. By default it uses self.trim_transmat, but will
129+ use self.transmat if no trimming has been done.
130+
131+ The weights are stored as node attributes in self.graph with the label
132+ 'label', and are also returned from the function.
133+ """
134+
135+ if self .trim_transmat is None :
136+ # use full transition matrix
137+ full_wts = eig_weights (self .transmat )
138+ self .add_attr (label , full_wts )
139+ else :
140+ # use trimmed transition matrix
141+ wts = eig_weights (self .trim_transmat )
142+ full_wts = np .zeros (self .nnodes ,dtype = float64 )
143+ for i ,ind in enumerate (self .trim_indices ):
144+ full_wts [ind ] = wts [i ]
145+ self .add_attr (label , full_wts )
146+
147+ return full_wts
148+
149+ def calc_mult_weights (self ,label = 'mult_weights' ,tol = 1e-6 ):
150+ """
151+ Calculates weights of states using iterative multiplication of the
152+ transition matrix. By default it uses self.trim_transmat, but will
153+ use self.transmat if no trimming has been done.
154+
155+ The weights are stored as node attributes in self.graph with the label
156+ 'label', and are also returned from the function.
157+ """
158+
159+ if self .trim_transmat is None :
160+ # use full transition matrix
161+ full_wts = mult_weights (self .transmat ,tol )
162+ self .add_attr (label , full_wts )
163+ else :
164+ # use trimmed transition matrix
165+ wts = mult_weights (self .trim_transmat ,tol )
166+ full_wts = np .zeros (self .nnodes ,dtype = float64 )
167+ for i ,ind in enumerate (self .trim_indices ):
168+ full_wts [ind ] = wts [i ]
169+ self .add_attr (label , full_wts )
170+
171+ return full_wts
172+
0 commit comments