Skip to content

Commit 9698a85

Browse files
committed
fixed iter mult bug
1 parent f8750cf commit 9698a85

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

csnanalysis/csn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,13 @@ def calc_committors(self,basins,labels=None,basin_labels=None,add_basins=False,t
329329

330330
return full_comm
331331

332+
def idxs_to_trim(self,idxs):
333+
"""
334+
Converts a list of idxs to trim_idxs.
335+
336+
idxs -- List of states in the transition matrix. Elements should be
337+
integers from 0 to nstates.
338+
"""
339+
340+
return [self.trim_indices.index(i) for i in idxs if i in self.trim_indices]
341+

csnanalysis/matrix.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import scipy
33
import numpy as np
4+
from itertools import compress
45

56
def count_to_trans(countmat):
67
"""
@@ -46,17 +47,20 @@ def _make_sink(transmat,sink_states):
4647
sink_mat.data[i] = 1.
4748
set_to_one[sink_states.index(sink_mat.col[i])] = True
4849

49-
# set diagonal elements to 1
50-
for i in range(len(sink_states)):
51-
if not set_to_one[i]:
52-
# add element sink_mat[sink_states[i]][sink_states[i]] = 1
53-
np.append(sink_mat.row,sink_states[i])
54-
np.append(sink_mat.col,sink_states[i])
55-
np.append(sink_mat.data,1.)
56-
50+
# set diagonal elements to 1 that haven't been set to one already
51+
statelist = list(compress(sink_states, np.logical_not(set_to_one)))
52+
sink_mat.row = np.append(sink_mat.row,statelist)
53+
sink_mat.col = np.append(sink_mat.col,statelist)
54+
sink_mat.data = np.append(sink_mat.data,[1 for i in statelist])
55+
5756
# remove zeros
5857
sink_mat.eliminate_zeros()
5958

59+
# check if sink_mat columns still sum to 1
60+
minval = sink_mat.toarray().sum(axis=0).min()
61+
if minval < 0.99999:
62+
raise ValueError("Error! Columns no longer sum to one in _make_sink!")
63+
6064
return sink_mat
6165

6266
def eig_weights(transmat):
@@ -107,14 +111,14 @@ def _trans_mult_iter(transmat,tol,maxstep=20):
107111

108112
var = 1
109113
step = 0
110-
while var > tol or step > maxstep:
114+
while (var > tol) and (step < maxstep):
111115
newmat = np.matmul(t,t)
112116
var = np.abs(newmat-t).max()
113117
t = newmat.copy()
114118
step += 1
115119

116-
if step > maxstep:
117-
print("Warning: iterative multiplication not converged after",maxstep,"steps: (var = ",var)
120+
if step == maxstep and var > tol:
121+
print("Warning: iterative multiplication not converged after",step,"steps: (var = ",var,"), (tol = ",tol,")")
118122

119123
return t
120124

@@ -198,7 +202,7 @@ def _extend(transmat,hubstates):
198202
ext_mat = scipy.sparse.coo_matrix((data, (rows, cols)), shape=(2*n,2*n))
199203
return ext_mat
200204

201-
def _getring(transmat,basin,eig_weights,tol,maxstep):
205+
def _getring(transmat,basin,wts,tol,maxstep):
202206
"""
203207
Given a transition matrix, and a set of states that form a basin,
204208
this returns a vector describing how probability exits that basin.
@@ -215,11 +219,11 @@ def _getring(transmat,basin,eig_weights,tol,maxstep):
215219
for b in basin:
216220
for i in range(n):
217221
if i not in basin:
218-
ringprob[i] += sink_results[i][b]
219-
220-
return ringprob
222+
ringprob[i] += wts[b]*sink_results[i][b]
223+
224+
return ringprob/wts[basin].sum()
221225

222-
def hubscores(transmat,hubstates,basins,tol=1e-6,maxstep=20,wts=None):
226+
def hubscores(transmat,hubstates,basins,tol=1e-6,maxstep=30,wts=None):
223227
"""
224228
This function computes hub scores, which are the probabilities that
225229
transitions between a set of communities will use a given community as
@@ -261,8 +265,9 @@ def hubscores(transmat,hubstates,basins,tol=1e-6,maxstep=20,wts=None):
261265
if wts is None:
262266
wts = eig_weights(transmat)
263267

268+
264269
h = np.zeros((2,2),dtype=float)
265-
ring = [_getring(transmat,b,eig_weights,tol,maxstep) for b in basins]
270+
ring = [_getring(transmat,b,wts,tol,maxstep) for b in basins]
266271

267272
for source,sink in [[0,1],[1,0]]:
268273
for i,p in enumerate(ring[source]):

0 commit comments

Comments
 (0)