|
1 | 1 | import scipy |
2 | 2 | import networkx as nx |
3 | 3 | import numpy as np |
4 | | -from csnanalysis.matrix import eig_weights, mult_weights |
| 4 | +from csnanalysis.matrix import eig_weights, mult_weights, committor |
5 | 5 |
|
6 | 6 | def count_to_trans(countmat): |
7 | 7 | """ |
@@ -63,6 +63,7 @@ def to_gephi(self, cols='all', node_name='node.csv', edge_name='edge.csv'): |
63 | 63 | """ |
64 | 64 | Writes node and edge files for import into the Gephi network visualization program. |
65 | 65 | """ |
| 66 | + |
66 | 67 |
|
67 | 68 | def add_attr(self, name, values): |
68 | 69 | """ |
@@ -170,3 +171,46 @@ def calc_mult_weights(self,label='mult_weights',tol=1e-6): |
170 | 171 |
|
171 | 172 | return full_wts |
172 | 173 |
|
| 174 | + def calc_committors(self,basins,labels=None,basin_labels=None,add_basins=False,tol=1e-6,maxstep=20): |
| 175 | + """ |
| 176 | + Calculates committor probabilities between an arbitrary set of N basins. |
| 177 | +
|
| 178 | + basins -- A list of lists, describing which states make up the |
| 179 | + basins of attraction. There can be any number of basins. |
| 180 | + e.g. [[basin1_a,basin1_b,...],[basin2_a,basin2_b,...]] |
| 181 | + labels -- A list of labels given to the committors (one for each |
| 182 | + basin) in the attribute list. |
| 183 | + add_basins -- Whether to add basin vectors to attribute list. |
| 184 | + basin_labels -- List of names of the basins. |
| 185 | + tol -- Tolerance of iterative multiplication process |
| 186 | + (see matrix.trans_mult_iter) |
| 187 | + maxstep -- Maximum number of iteractions of multiplication process. |
| 188 | +
|
| 189 | + The committors are also returned from the function as a numpy array. |
| 190 | + """ |
| 191 | + |
| 192 | + if self.trim_transmat is None: |
| 193 | + # use full transition matrix |
| 194 | + full_comm = committor(self.transmat,basins,tol=tol,maxstep=maxstep) |
| 195 | + else: |
| 196 | + # use trimmed transition matrix |
| 197 | + comm = committor(self.transmat,basins,tol=tol,maxstep=maxstep) |
| 198 | + full_comm = np.zeros(self.nnodes,dtype=float64) |
| 199 | + for i,ind in enumerate(self.trim_indices): |
| 200 | + full_comm[ind] = comm[i] |
| 201 | + |
| 202 | + if labels is None: |
| 203 | + labels = ['p' + str(i) for i in range(len(basins))] |
| 204 | + for i,b in enumerate(basins): |
| 205 | + self.add_attr(labels[i], full_comm[:,i]) |
| 206 | + |
| 207 | + if add_basins: |
| 208 | + if basin_labels is None: |
| 209 | + basin_labels = [str(i) for i in range(len(basins))] |
| 210 | + for i,b in enumerate(basins): |
| 211 | + bvec = np.zeros(self.nnodes,dtype=int) |
| 212 | + bvec[b] = 1 |
| 213 | + self.add_attr(basin_labels[i],bvec) |
| 214 | + |
| 215 | + return full_comm |
| 216 | + |
0 commit comments