3737from sklearn .linear_model import Ridge
3838from sklearn .model_selection import GridSearchCV , train_test_split
3939from scipy .optimize import differential_evolution
40+ from Bio import SeqIO , BiopythonParserWarning
41+ warnings .filterwarnings (action = 'ignore' , category = BiopythonParserWarning )
4042
4143from pypef .utils .variant_data import (
4244 get_sequences_from_file , get_seqs_from_var_name ,
@@ -327,16 +329,28 @@ def _adjust_betas(
327329 return minimizer .x
328330
329331 def get_subsplits_train (self , train_size_fit : float = 0.66 ):
332+ print ("Getting subsplits for supervised (re-)training of models "
333+ "and for adjustment of hybrid component contribution "
334+ "weights (\" beta's\" )..."
335+ )
336+ train_size_fit = int (train_size_fit * len (self .y_train ))
337+ train_size_beta_adjustment = len (self .y_train ) - train_size_fit
338+ print (f"Splitting training data of size { len (self .y_train )} "
339+ f"into { train_size_fit } variants for model tuning and "
340+ f"{ train_size_beta_adjustment } variants for hybrid model "
341+ f"beta adjustment..." )
330342 if len (self .parameter_range ) == 4 :
331343 # Reduce sizes by batch modulo
332- train_size_fit = int (
333- (train_size_fit * len (self .y_train )) -
334- ((train_size_fit * len (self .y_train )) % self .batch_size )
335- )
336- #train_test_size = int(
337- # (len(self.y_train) - train_size_fit) -
338- # ((len(self.y_train) - train_size_fit) % self.batch_size)
339- #)
344+ n_drop = train_size_fit % self .batch_size
345+ if n_drop > 0 :
346+ train_size_fit = train_size_fit - n_drop
347+ train_size_beta_adjustment = len (self .y_train ) - train_size_fit
348+ print (f"Shifting { n_drop } variants from training set to "
349+ f"beta adjustment set to match batch requirements "
350+ f"of batch size { self .batch_size } for LLM retraining "
351+ f"resulting in { train_size_fit } variants for model "
352+ f"tuning and { train_size_beta_adjustment } variants "
353+ f"for hybrid model beta adjustment..." )
340354 (
341355 self .x_dca_ttrain , self .x_dca_ttest ,
342356 self .x_llm_ttrain , self .x_llm_ttest ,
@@ -348,14 +362,6 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
348362 train_size = train_size_fit ,
349363 random_state = self .seed
350364 )
351- # Reducing by batch size modulo for X and y
352- self .x_dca_ttrain = self .x_dca_ttrain [:train_size_fit ]
353- self .x_llm_ttrain = self .x_llm_ttrain [:train_size_fit ]
354- self .y_ttrain = self .y_ttrain [:train_size_fit ]
355- #self.x_dca_ttest = self.x_dca_ttest[:train_test_size]
356- #self.x_llm_ttest = self.x_llm_ttest[:train_test_size]
357- #self.y_ttest = self.y_ttest[:train_test_size]
358-
359365 else :
360366 (
361367 self .x_dca_ttrain , self .x_dca_ttest ,
@@ -526,12 +532,15 @@ def train_and_optimize(self) -> tuple:
526532 if len (self .parameter_range ) == 4 :
527533 self .train_llm ()
528534 self .beta1 , self .beta2 , self .beta3 , self .beta4 = self ._adjust_betas (
529- self .y_ttest , self .y_dca_ttest , self .y_dca_ridge_ttest , self .y_llm_ttest , self .y_llm_lora_ttest
535+ self .y_ttest , self .y_dca_ttest , self .y_dca_ridge_ttest ,
536+ self .y_llm_ttest , self .y_llm_lora_ttest
530537 )
531538 return self .beta1 , self .beta2 , self .beta3 , self .beta4 , self .ridge_opt
532539
533540 else :
534- self .beta1 , self .beta2 = self ._adjust_betas (self .y_ttest , self .y_dca_ttest , self .y_dca_ridge_ttest )
541+ self .beta1 , self .beta2 = self ._adjust_betas (self .y_ttest ,
542+ self .y_dca_ttest , self .y_dca_ridge_ttest
543+ )
535544 return self .beta1 , self .beta2 , self .ridge_opt
536545
537546
@@ -607,7 +616,10 @@ def hybrid_prediction(
607616 self .llm_model ,
608617 device = self .device ).detach ().cpu ().numpy ()
609618
610- return self .beta1 * y_dca + self .beta2 * y_ridge + self .beta3 * y_llm + self .beta4 * y_llm_lora
619+ return (
620+ self .beta1 * y_dca + self .beta2 * y_ridge +
621+ self .beta3 * y_llm + self .beta4 * y_llm_lora
622+ )
611623
612624 def ls_ts_performance (self ):
613625 beta_1 , beta_2 , reg = self .settings (
@@ -724,15 +736,20 @@ def get_model_path(model: str):
724736 model_path = f'Pickles/{ model } '
725737 else :
726738 raise SystemError (
727- "Did not find specified model file in current working directory "
728- " or /Pickles subdirectory. Make sure to train/save a model first "
729- "(e.g., for saving a GREMLIN model, type \" pypef param_inference --msa TARGET_MSA.a2m\" "
730- "or, for saving a plmc model, type \" pypef param_inference --params TARGET_PLMC.params\" )."
739+ "Did not find specified model file in current "
740+ "working directory or /Pickles subdirectory. "
741+ "Make sure to train/save a model first (e.g., "
742+ "for saving a GREMLIN model, type \" pypef "
743+ "param_inference --msa TARGET_MSA.a2m\" or, for "
744+ "saving a plmc model, type \" pypef param_inference"
745+ " --params TARGET_PLMC.params\" )."
731746 )
732747 return model_path
733748 except TypeError :
734- raise SystemError ("No provided model. "
735- "Specify a model for DCA-based encoding." )
749+ raise SystemError (
750+ "No provided model. Specify a " \
751+ "model for DCA-based encoding."
752+ )
736753
737754
738755def get_model_and_type (
@@ -772,11 +789,7 @@ def get_model_and_type(
772789
773790def save_model_to_dict_pickle (
774791 model : DCALLMHybridModel | PLMC | GREMLIN ,
775- model_type : str | None = None ,
776- beta_1 : float | None = None ,
777- beta_2 : float | None = None ,
778- spearman_r : float | None = None ,
779- regressor : sklearn .base .BaseEstimator = None
792+ model_type : str | None = None
780793):
781794 try :
782795 os .mkdir ('Pickles' )
@@ -790,11 +803,7 @@ def save_model_to_dict_pickle(
790803 pickle .dump (
791804 {
792805 'model' : model ,
793- 'model_type' : model_type ,
794- 'beta_1' : beta_1 ,
795- 'beta_2' : beta_2 ,
796- 'spearman_rho' : spearman_r ,
797- 'regressor' : regressor
806+ 'model_type' : model_type
798807 },
799808 open (f'Pickles/{ model_type } ' , 'wb' )
800809 )
@@ -816,19 +825,21 @@ def plmc_or_gremlin_encoding(
816825 use_global_model = False
817826):
818827 """
819- Decides based on the params file input type which DCA encoding to be performed, i.e.,
820- GREMLIN or PLMC.
821- If use_global_model==True, to avoid each time pickle model file getting loaded, which
822- is quite inefficient when performing directed evolution, i.e., encoding of single
823- sequences, a global model is stored at the first evolution step and used in the
824- subsequent steps.
828+ Decides, based on the params file input type, which DCA
829+ encoding is to be performed, i.e., GREMLIN or PLMC.
830+ If use_global_model==True, a global model is stored at the
831+ first evolution step and reused in the subsequent steps. This
832+ avoids reloading the pickled model file each time, which is
833+ quite inefficient when performing directed evolution, i.e.,
834+ when encoding single sequences.
825835 """
826836 global global_model , global_model_type
827837 if ys_true is None :
828838 ys_true = np .zeros (np .shape (sequences ))
829839 if use_global_model :
830840 if global_model is None :
831- global_model , global_model_type = get_model_and_type (params_file , substitution_sep )
841+ global_model , global_model_type = get_model_and_type (
842+ params_file , substitution_sep )
832843 model , model_type = global_model , global_model_type
833844 else :
834845 model , model_type = global_model , global_model_type
@@ -840,12 +851,16 @@ def plmc_or_gremlin_encoding(
840851 )
841852 elif model_type == 'GREMLIN' :
842853 if verbose :
843- print (f"Following positions are frequent gap positions in the MSA "
844- f"and cannot be considered for effective modeling, i.e., "
845- f"substitutions at these positions are removed as these would be "
846- f"predicted with wild-type fitness:\n { [int (gap ) + 1 for gap in model .gaps ]} .\n "
847- f"Effective positions (N={ len (model .v_idx )} ) are:\n "
848- f"{ [int (v_pos ) + 1 for v_pos in model .v_idx ]} " )
854+ print (
855+ f"Following positions are frequent gap positions "
856+ f"in the MSA and cannot be considered for effective "
857+ f"modeling, i.e., substitutions at these positions "
858+ f"are removed as these would be predicted with "
859+ f"wild-type fitness:"
860+ f"\n { [int (gap ) + 1 for gap in model .gaps ]} .\n "
861+ f"Effective positions (N={ len (model .v_idx )} ) are:\n "
862+ f"{ [int (v_pos ) + 1 for v_pos in model .v_idx ]} "
863+ )
849864 xs , x_wt , variants , sequences , ys_true = gremlin_encoding (
850865 model , variants , sequences , ys_true ,
851866 shift_pos = 1 , substitution_sep = substitution_sep
@@ -920,11 +935,14 @@ def remove_gap_pos(
920935 Returns
921936 -----------
922937 variants_v
923- Variants with substitutions at valid sequence positions, i.e., at non-gap positions
938+ Variants with substitutions at valid sequence positions,
939+ i.e., at non-gap positions
924940 sequences_v
925- Sequences of variants with substitutions at valid sequence positions, i.e., at non-gap positions
941+ Sequences of variants with substitutions at valid sequence positions,
942+ i.e., at non-gap positions
926943 fitnesses_v
927- Fitness values of variants with substitutions at valid sequence positions, i.e., at non-gap positions
944+ Fitness values of variants with substitutions at valid sequence positions,
945+ i.e., at non-gap positions
928946 """
929947 variants_v , sequences_v , fitnesses_v = [], [], []
930948 valid = []
@@ -1029,12 +1047,12 @@ def llm_embedder(llm_dict, seqs):
10291047 np .shape (seqs )
10301048 #except np.shape error:
10311049 if list (llm_dict .keys ())[0 ] == 'esm1v' :
1032- x_llm_seqs = esm_tokenize_sequences (
1033- seqs , llm_dict ['esm1v' ]['llm_tokenizer' ], max_length = len (seqs [0 ])
1050+ x_llm_seqs , _attention_mask = esm_tokenize_sequences (
1051+ seqs , tokenizer = llm_dict ['esm1v' ]['llm_tokenizer' ], max_length = len (seqs [0 ])
10341052 )
10351053 elif list (llm_dict .keys ())[0 ] == 'prosst' :
10361054 x_llm_seqs = prosst_tokenize_sequences (
1037- seqs , llm_dict ['prosst' ]['llm_tokenizer' ], max_length = len ( seqs [ 0 ])
1055+ seqs , vocab = llm_dict ['prosst' ]['llm_vocab' ]
10381056 )
10391057 else :
10401058 raise SystemError (f"Unknown LLM dictionary input:\n { list (llm_dict .keys ())[0 ]} " )
@@ -1048,8 +1066,8 @@ def performance_ls_ts(
10481066 params_file : str ,
10491067 model_pickle_file : str | None = None ,
10501068 llm : str | None = None ,
1051- wt_seq : str | None = None ,
10521069 pdb_file : str | None = None ,
1070+ wt_seq : str | None = None ,
10531071 substitution_sep : str = '/' ,
10541072 label = False
10551073):
@@ -1091,32 +1109,58 @@ def performance_ls_ts(
10911109 test_sequences , test_variants , y_test = get_sequences_from_file (ts_fasta )
10921110
10931111 if ls_fasta is not None and ts_fasta is not None :
1094- train_sequences , train_variants , y_train = get_sequences_from_file (ls_fasta )
1095- x_train , train_variants , train_sequences , y_train , x_wt , _ , model_type = plmc_or_gremlin_encoding (
1096- train_variants , train_sequences , y_train , params_file , substitution_sep , threads
1112+ train_sequences , train_variants , y_train = get_sequences_from_file (
1113+ ls_fasta )
1114+ (
1115+ x_train , train_variants , train_sequences ,
1116+ y_train , x_wt , _ , model_type
1117+ ) = plmc_or_gremlin_encoding (
1118+ train_variants , train_sequences , y_train ,
1119+ params_file , substitution_sep , threads
10971120 )
10981121
1099- x_test , test_variants , test_sequences , y_test , * _ = plmc_or_gremlin_encoding (
1100- test_variants , test_sequences , y_test , params_file , substitution_sep , threads , verbose = False
1122+ (
1123+ x_test , test_variants , test_sequences , y_test , * _
1124+ ) = plmc_or_gremlin_encoding (
1125+ test_variants , test_sequences , y_test , params_file ,
1126+ substitution_sep , threads , verbose = False
11011127 )
11021128
11031129 print (f"\n Initial training set variants: { len (train_sequences )} . "
11041130 f"Remaining: { len (train_variants )} (after removing "
11051131 f"substitutions at gap positions).\n Initial test set "
1106- f"variants: { len (test_sequences )} . Remaining: { len (test_variants )} "
1107- f"(after removing substitutions at gap positions)."
1132+ f"variants: { len (test_sequences )} . Remaining: "
1133+ f"{ len (test_variants )} (after removing substitutions "
1134+ f"at gap positions)."
11081135 )
11091136 if llm == 'esm' :
11101137 llm_dict = esm_setup (train_sequences )
11111138 x_llm_test = llm_embedder (llm_dict , test_sequences )
11121139 elif llm == 'prosst' :
1113- llm_dict = prosst_setup (wt_seq , pdb_file , sequences = train_sequences )
1140+ if pdb_file is None :
1141+ raise SystemError (
1142+ "Running ProSST requires a PDB file input "
1143+ "for embedding sequences! Specify a PDB file "
1144+ "with the --pdb flag."
1145+ )
1146+ if wt_seq is None :
1147+ raise SystemError (
1148+ "Running ProSST requires a wild-type sequence "
1149+ "FASTA file input for embedding sequences! "
1150+ "Specify a FASTA file with the --wt flag."
1151+ )
1152+ pdb_seq = str (list (SeqIO .parse (pdb_file , "pdb-atom" ))[0 ].seq )
1153+ assert wt_seq == pdb_seq , (
1154+ f"Wild-type sequence is not matching PDB-extracted sequence:"
1155+ f"\n WT sequence:\n { wt_seq } \n PDB sequence:\n { pdb_seq } "
1156+ )
1157+ llm_dict = prosst_setup (
1158+ wt_seq , pdb_file , sequences = train_sequences )
11141159 x_llm_test = llm_embedder (llm_dict , test_sequences )
11151160 else :
11161161 llm_dict = None
11171162 x_llm_test = None
11181163 llm = ''
1119-
11201164 hybrid_model = DCALLMHybridModel (
11211165 x_train_dca = np .array (x_train ),
11221166 y_train = np .array (y_train ),
@@ -1132,11 +1176,19 @@ def performance_ls_ts(
11321176 print (f'Taking model from saved model (Pickle file): { model_pickle_file } ...' )
11331177 model , model_type = get_model_and_type (model_pickle_file )
11341178 if model_type != 'Hybrid' : # same as below in next elif
1135- x_test , test_variants , test_sequences , y_test , x_wt , * _ = plmc_or_gremlin_encoding (
1136- test_variants , test_sequences , y_test , model_pickle_file , substitution_sep , threads , False )
1179+ (
1180+ x_test , test_variants , test_sequences ,
1181+ y_test , x_wt , * _
1182+ ) = plmc_or_gremlin_encoding (
1183+ test_variants , test_sequences , y_test , model_pickle_file ,
1184+ substitution_sep , threads , False
1185+ )
11371186 y_test_pred = get_delta_e_statistical_model (x_test , x_wt )
11381187 else : # Hybrid model input requires params from plmc or GREMLIN model
1139- x_test , test_variants , test_sequences , y_test , * _ = plmc_or_gremlin_encoding (
1188+ (
1189+ x_test , test_variants , test_sequences ,
1190+ y_test , * _
1191+ ) = plmc_or_gremlin_encoding (
11401192 test_variants , test_sequences , y_test , params_file ,
11411193 substitution_sep , threads , False
11421194 )
0 commit comments