Skip to content

Commit 0220639

Browse files
committed
committor wrapper
1 parent 4bd559f commit 0220639

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

csnanalysis/csn.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import scipy
22
import networkx as nx
33
import numpy as np
4-
from csnanalysis.matrix import eig_weights, mult_weights
4+
from csnanalysis.matrix import eig_weights, mult_weights, committor
55

66
def count_to_trans(countmat):
77
"""
@@ -63,6 +63,7 @@ def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv'):
6363
"""
6464
Writes node and edge files for import into the Gephi network visualization program.
6565
"""
66+
6667

6768
def add_attr(self, name, values):
6869
"""
@@ -170,3 +171,46 @@ def calc_mult_weights(self,label='mult_weights',tol=1e-6):
170171

171172
return full_wts
172173

174+
def calc_committors(self,basins,labels=None,basin_labels=None,add_basins=False,tol=1e-6,maxstep=20):
175+
"""
176+
Calculates committor probabilities between an arbitrary set of N basins.
177+
178+
basins -- A list of lists, describing which states make up the
179+
basins of attraction. There can be any number of basins.
180+
e.g. [[basin1_a,basin1_b,...],[basin2_a,basin2_b,...]]
181+
labels -- A list of labels given to the committors (one for each
182+
basin) in the attribute list.
183+
add_basins -- Whether to add basin vectors to attribute list.
184+
basin_labels -- List of names of the basins.
185+
tol -- Tolerance of iterative multiplication process
186+
(see matrix.trans_mult_iter)
187+
maxstep -- Maximum number of iteractions of multiplication process.
188+
189+
The committors are also returned from the function as a numpy array.
190+
"""
191+
192+
if self.trim_transmat is None:
193+
# use full transition matrix
194+
full_comm = committor(self.transmat,basins,tol=tol,maxstep=maxstep)
195+
else:
196+
# use trimmed transition matrix
197+
comm = committor(self.transmat,basins,tol=tol,maxstep=maxstep)
198+
full_comm = np.zeros(self.nnodes,dtype=float64)
199+
for i,ind in enumerate(self.trim_indices):
200+
full_comm[ind] = comm[i]
201+
202+
if labels is None:
203+
labels = ['p' + str(i) for i in range(len(basins))]
204+
for i,b in enumerate(basins):
205+
self.add_attr(labels[i], full_comm[:,i])
206+
207+
if add_basins:
208+
if basin_labels is None:
209+
basin_labels = [str(i) for i in range(len(basins))]
210+
for i,b in enumerate(basins):
211+
bvec = np.zeros(self.nnodes,dtype=int)
212+
bvec[b] = 1
213+
self.add_attr(basin_labels[i],bvec)
214+
215+
return full_comm
216+

0 commit comments

Comments
 (0)