4949from tqdm import tqdm
5050import torch
5151
52+ from pypef .llm .utils import get_batches
53+ from pypef .utils .variant_data import get_mismatches
54+
5255
5356class GREMLIN :
5457 """
@@ -65,6 +68,8 @@ def __init__(
6568 eff_cutoff = 0.8 ,
6669 opt_iter = 100 ,
6770 max_msa_seqs : int | None = 10000 ,
71+ msa_start : None | int = None ,
72+ msa_end : None | int = None ,
6873 seqs : list [str ] | np .ndarray [str ] | None = None ,
6974 device : str | None = None
7075 ):
@@ -98,12 +103,20 @@ def __init__(
98103 else :
99104 self .max_msa_seqs = max_msa_seqs
100105 self .states = len (self .char_alphabet )
106+ self .msa_start = msa_start
107+ if msa_end == 0 :
108+ msa_end = None
109+ self .msa_end = msa_end
101110 logger .info ('Loading MSA...' )
102111 if seqs is None :
103112 self .seqs , self .seq_ids = self .get_sequences_from_msa (alignment )
104113 else :
105114 self .seqs = seqs
106115 self .seq_ids = np .array ([n for n in range (len (self .seqs ))])
116+ self .first_msa_seq = self .seqs [0 ]
117+ if self .msa_start is not None or self .msa_end is not None :
118+ logger .info (f'Trimmed sequence length.. first sequence is printed here as '
119+ f'example (Length: { len (self .first_msa_seq )} ): { self .first_msa_seq } ' )
107120 logger .info (f'Found { len (self .seqs )} sequences in the MSA...' )
108121 self .msa_ori = self .get_msa_ori ()
109122 logger .info (f'MSA shape: { np .shape (self .msa_ori )} ' )
@@ -153,7 +166,14 @@ def get_sequences_from_msa(self, msa_file: str):
153166 with open (msa_file , 'r' ) as fh :
154167 alignment = AlignIO .read (fh , "fasta" )
155168 for record in alignment :
156- sequences .append (str (record .seq ))
169+ seq = str (record .seq )
170+ if self .msa_start is not None and self .msa_end is not None :
171+ seq = seq [self .msa_start :self .msa_end ]
172+ elif self .msa_start is not None :
173+ seq = seq [self .msa_start :]
174+ elif self .msa_end is not None :
175+ seq = seq [:self .msa_end ]
176+ sequences .append (seq )
157177 seq_ids .append (str (record .id ))
158178 assert len (sequences ) == len (seq_ids ), f"{ len (sequences )} , { len (seq_ids )} "
159179 return np .array (sequences ), np .array (seq_ids )
@@ -353,14 +373,14 @@ def run_optimization(self):
353373
354374 self .mt_v , self .vt_v = torch .zeros_like (self .v ), torch .zeros_like (self .v )
355375 self .mt_w , self .vt_w = torch .zeros_like (self .w ), torch .zeros_like (self .w )
356- logger .info (f'Initial loss: { self ._loss ()} ' )
376+ logger .info (f'Initial loss: { self ._loss ():.5f } ' )
357377 for i in range (self .opt_iter ):
358378 self .opt_adam_step ()
359379 try :
360380 if (i + 1 ) % int (self .opt_iter / 10 ) == 0 :
361- logger .info (f'Loss step { i + 1 } : { self ._loss ()} ' )
381+ logger .info (f'Loss step { i + 1 } : { self ._loss ():.5f } ' )
362382 except ZeroDivisionError :
363- logger .info (f'Loss step { i + 1 } : { self ._loss ()} ' )
383+ logger .info (f'Loss step { i + 1 } : { self ._loss ():.5f } ' )
364384
365385 self .v = self .v .detach ().cpu ().numpy ()
366386 self .w = self .w .detach ().cpu ().numpy ()
@@ -416,7 +436,20 @@ def get_scores(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.
416436 if v_idx is None :
417437 v_idx = self .v_idx
418438 seqs_int = self .seq2int (seqs )
419-
439+ wt_seq_len = len (self .wt_seq )
440+ #if np.shape(seqs_int)[1] != wt_seq_len:
441+ # raise RuntimeError(
442+ # f"Input sequence shape (length: {np.shape(seqs_int)[1]}) does not match GREMLIN "
443+ # f"MSA shape (common sequence length: {wt_seq_len}) inferred from the MSA."
444+ # )
445+ # Check nums of mutations to MSA first/WT sequence and gives warning if too apart from MSA seq
446+ for i , seq in enumerate (seqs ):
447+ n_mismatches , mismatches = get_mismatches (self .wt_seq , seq )
448+ if n_mismatches / wt_seq_len > 0.05 :
449+ logger .warning (
450+ f"Sequence { mismatches } contains more than 5% sequence mismatches to the "
451+ f"first MSA/\" WT\" sequence. Effect predictions will likely be incorrect!"
452+ )
420453 try :
421454 if seqs_int .shape [- 1 ] != len (v_idx ): # The input sequence length ({seqs_int.shape[-1]})
422455 # does not match the common gap-trimmed MSA sequence length (len(v_idx)
@@ -471,8 +504,16 @@ def collect_encoded_sequences(self, seqs, v=None, w=None, v_idx=None):
471504 Wrapper function for encoding input sequences using the self.get_scores
472505 function with encode set to True.
473506 """
474- xs = self .get_scores (seqs , v , w , v_idx , encode = True )
475- return xs
507+ xs = []
508+ sequences_batched = get_batches (
509+ seqs , batch_size = 1000 , dtype = str ,
510+ keep_remaining = True , verbose = True
511+ )
512+ sequences_batched = np .atleast_2d (sequences_batched )
513+
514+ for seq_batch in sequences_batched :
515+ xs .append (self .get_scores (seq_batch , v , w , v_idx , encode = True ))
516+ return xs [0 ]
476517
477518 @staticmethod
478519 def normalize (apc_mat ):
@@ -691,7 +732,7 @@ def save_corr_csv(gremlin: GREMLIN, min_distance: int = 0, sort_by: str = 'apc')
691732 )
692733 df_mtx_sorted_mindist .to_csv (f"coevolution_{ sort_by } _sorted.csv" , sep = ',' )
693734 logger .info (f"Saved coevolution CSV data as "
694- f"{ os .path .abspath (f'coevolution_{ sort_by } _sorted.csv' )} " )
735+ f"{ os .path .abspath (f'coevolution_{ sort_by } _sorted.csv' )} " )
695736
696737
697738def plot_predicted_ssm (gremlin : GREMLIN ):
0 commit comments