@@ -79,7 +79,6 @@ def __init__(
7979 self .eff_cutoff = eff_cutoff
8080 self .opt_iter = opt_iter
8181 self .states = len (self .char_alphabet )
82- self .a2n = self .a2n_dict ()
8382 self .seqs , _ , _ = get_sequences_from_file (alignment )
8483 self .msa_ori = self .get_msa_ori ()
8584 self .n_col_ori = self .msa_ori .shape [1 ]
@@ -97,7 +96,7 @@ def __init__(
9796 self .n_eff = np .sum (self .msa_weights )
9897 self .n_row = self .msa_trimmed .shape [0 ]
9998 self .n_col = self .msa_trimmed .shape [1 ]
100- self .v_ini , self .w_ini = self .initialize_v_w (remove_gap_entries = False )
99+ self .v_ini , self .w_ini , self . aa_counts = self .initialize_v_w (remove_gap_entries = False )
101100 self .optimize = optimize
102101 if self .optimize :
103102 self .v_opt , self .w_opt = self .run_opt_tf ()
@@ -110,27 +109,30 @@ def a2n_dict(self):
110109 return a2n
111110
112111 def aa2int (self , aa ):
113- """convert single aa into numerical integer value, e.g.
112+ """convert single aa into numerical integer value, e.g.:
114113 "A" -> 0 or "-" to 21 dependent on char_alphabet"""
115- if aa in self .a2n :
116- return self .a2n [aa ]
114+ a2n = self .a2n_dict ()
115+ if aa in a2n :
116+ return a2n [aa ]
117117 else : # for unknown characters insert Gap character
118- return self . a2n ['-' ]
118+ return a2n ['-' ]
119119
120- def str2int (self , x ):
120+ def seq2int (self , aa_seqs ):
121121 """
122- convert a list of strings into list of integers
123- Example: ["ACD","EFG"] -> [[0,4,3], [6,13,7]]
122+ convert a single sequence or a list of sequences into a list of integer sequences, e.g.:
123+ ["ACD","EFG"] -> [[0,4,3], [6,13,7]]
124124 """
125- if type (x ) == list :
126- x = np .array (x )
127- if x .dtype .type is np .str_ :
128- if x .ndim == 0 : # single seq
129- return np .array ([self .aa2int (aa ) for aa in str (x )])
125+ if type (aa_seqs ) == str :
126+ aa_seqs = np .array (aa_seqs )
127+ if type (aa_seqs ) == list :
128+ aa_seqs = np .array (aa_seqs )
129+ if aa_seqs .dtype .type is np .str_ :
130+ if aa_seqs .ndim == 0 : # single seq
131+ return np .array ([self .aa2int (aa ) for aa in str (aa_seqs )])
130132 else : # list of seqs
131- return np .array ([[self .aa2int (aa ) for aa in seq ] for seq in x ])
133+ return np .array ([[self .aa2int (aa ) for aa in seq ] for seq in aa_seqs ])
132134 else :
133- return x
135+ return aa_seqs
134136
135137 @property
136138 def get_v_idx_w_idx (self ):
@@ -150,7 +152,7 @@ def filt_gaps(self, msa_ori):
150152 non_gaps = np .where (np .sum (tmp .T , - 1 ).T / msa_ori .shape [0 ] < self .gap_cutoff )[0 ]
151153
152154 gaps = np .where (np .sum (tmp .T , - 1 ).T / msa_ori .shape [0 ] >= self .gap_cutoff )[0 ]
153- logger .info (f'Gap positions (removed from msa ):\n { gaps } ' )
155+ logger .info (f'Gap positions (removed from MSA; 0-indexed ):\n { gaps } ' )
154156 ncol_trimmed = len (non_gaps )
155157 v_idx = non_gaps
156158 w_idx = v_idx [np .stack (np .triu_indices (ncol_trimmed , 1 ), - 1 )]
@@ -362,17 +364,19 @@ def initialize_v_w(self, remove_gap_entries=True):
362364 """
363365 w_ini = np .zeros ((self .n_col , self .states , self .n_col , self .states ))
364366 onehot_cat_msa = np .eye (self .states )[self .msa_trimmed ]
367+ aa_counts = np .sum (onehot_cat_msa , axis = 0 )
365368 pseudo_count = 0.01 * np .log (self .n_eff )
366369 v_ini = np .log (np .sum (onehot_cat_msa .T * self .msa_weights , - 1 ).T + pseudo_count )
367370 v_ini = v_ini - np .mean (v_ini , - 1 , keepdims = True )
368- # loss_score_ini = self.objective(v_ini, w_ini, flattened=False) # * self.n_eff
371+ # loss_score_ini = self.objective(v_ini, w_ini, flattened=False)
369372
370373 if remove_gap_entries :
371374 no_gap_states = self .states - 1
372375 v_ini = v_ini [:, :no_gap_states ]
373376 w_ini = w_ini [:, :no_gap_states , :, :no_gap_states ]
377+ aa_counts = aa_counts [:, :no_gap_states ]
374378
375- return v_ini , w_ini
379+ return v_ini , w_ini , aa_counts
376380
377381 @property
378382 def get_v_w_opt (self ):
@@ -390,10 +394,10 @@ def get_score(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.0
390394 if self .optimize :
391395 v , w = self .v_opt , self .w_opt
392396 else :
393- v , w = self .v_ini , self . w_ini
397+ v , w , _ = self .initialize_v_w ( remove_gap_entries = True )
394398 if v_idx is None :
395399 v_idx = self .v_idx
396- seqs_int = self .str2int (seqs )
400+ seqs_int = self .seq2int (seqs )
397401 # if length of sequence != length of model use only
398402 # valid positions (v_idx) from the trimmed alignment
399403 try :
@@ -439,7 +443,7 @@ def get_score(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.0
439443 else :
440444 return np .sum (h , axis = - 1 ) - h_wt_seq
441445
442- def get_wt_score (self , wt_seq = None , v = None , w = None ):
446+ def get_wt_score (self , wt_seq = None , v = None , w = None , encode = False ):
443447 if wt_seq is None :
444448 wt_seq = self .wt_seq
445449 if v is None or w is None :
@@ -448,7 +452,7 @@ def get_wt_score(self, wt_seq=None, v=None, w=None):
448452 else :
449453 v , w = self .v_ini , self .w_ini
450454 wt_seq = np .array (wt_seq , dtype = str )
451- return self .get_score (wt_seq , v , w )
455+ return self .get_score (wt_seq , v , w , encode = encode )
452456
453457 def collect_encoded_sequences (self , seqs , v = None , w = None , v_idx = None ):
454458 """
@@ -541,17 +545,22 @@ def plot_correlation_matrix(self, matrix_type: str = 'apc', set_diag_zero=True):
541545 else :
542546 ax .imshow (matrix , cmap = 'Blues' )
543547 tick_pos = ax .get_xticks ()
548+ tick_pos = np .array ([int (t ) for t in tick_pos ])
544549 tick_pos [- 1 ] = matrix .shape [0 ]
545- tick_pos [2 :] -= 1
550+ if tick_pos [2 ] > 1 :
551+ tick_pos [2 :] -= 1
546552 ax .set_xticks (tick_pos )
547553 ax .set_yticks (tick_pos )
548554 labels = [item .get_text () for item in ax .get_xticklabels ()]
549- labels = [labels [0 ]] + [str (int (label ) + 1 ) for label in labels [1 :]]
555+ try :
556+ labels = [labels [0 ]] + [str (int (label ) + 1 ) for label in labels [1 :]]
557+ except ValueError :
558+ pass
550559 ax .set_xticklabels (labels )
551560 ax .set_yticklabels (labels )
552561 ax .set_xlim (- 1 , matrix .shape [0 ])
553562 ax .set_ylim (- 1 , matrix .shape [0 ])
554- plt .title (matrix_type )
563+ plt .title (matrix_type . upper () )
555564 plt .savefig (f'{ matrix_type } .png' , dpi = 500 )
556565 plt .close ('all' )
557566
0 commit comments