4848from pypef .utils .plot import plot_y_true_vs_y_pred
4949import pypef .dca .gremlin_inference
5050from pypef .dca .gremlin_inference import GREMLIN , get_delta_e_statistical_model
51- from pypef .llm .esm_lora_tune import get_batches
51+ from pypef .llm .esm_lora_tune import esm_tokenize_sequences , get_batches , esm_setup
52+ from pypef .llm .prosst_lora_tune import prosst_setup , prosst_tokenize_sequences
5253
5354# sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
5455# will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function
@@ -396,16 +397,7 @@ def train_llm(self):
396397 #x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
397398 if self .llm_key == 'prosst' :
398399 y_llm_ttest = self .llm_inference_function (
399- xs = self .x_llm_ttest ,
400- model = self .llm_base_model ,
401- input_ids = self .input_ids ,
402- attention_mask = self .llm_attention_mask ,
403- structure_input_ids = self .structure_input_ids ,
404- train = True ,
405- device = self .device
406- )
407- y_llm_ttrain = self .llm_inference_function (
408- xs = self .x_llm_ttrain ,
400+ x_sequences = self .x_llm_ttest ,
409401 model = self .llm_base_model ,
410402 input_ids = self .input_ids ,
411403 attention_mask = self .llm_attention_mask ,
@@ -451,17 +443,8 @@ def train_llm(self):
451443 device = self .device ,
452444 #seed=self.seed
453445 )
454- y_llm_lora_ttrain = self .llm_inference_function (
455- xs = self .x_llm_ttrain ,
456- model = self .llm_model ,
457- input_ids = self .input_ids ,
458- attention_mask = self .llm_attention_mask ,
459- structure_input_ids = self .structure_input_ids ,
460- train = True ,
461- device = self .device
462- )
463446 y_llm_lora_ttest = self .llm_inference_function (
464- xs = self .x_llm_ttest ,
447+ x_sequences = self .x_llm_ttest ,
465448 model = self .llm_model ,
466449 input_ids = self .input_ids ,
467450 attention_mask = self .llm_attention_mask ,
@@ -575,10 +558,10 @@ def hybrid_prediction(
575558 if self .llm_attention_mask is not None :
576559 print ('No LLM input for hybrid prediction but the model '
577560 'has been trained using an LLM model input.. '
578- 'Using only DCA for hybrid prediction .. This can lead '
561+ 'Using only DCA for hybridprediction .. This can lead '
579562 'to unwanted prediction behavior if the hybrid model '
580563 'is trained including an LLM...' )
581- return self .beta1 * y_dca + self .beta2 * y_ridge
564+ return self .beta1 * y_dca + self .beta2
582565
583566 else :
584567 if self .llm_key == 'prosst' :
@@ -601,7 +584,7 @@ def hybrid_prediction(
601584 #desc='Infering LoRA-tuned model',
602585 device = self .device ).detach ().cpu ().numpy ()
603586 elif self .llm_key == 'esm1v' :
604- x_llm_b = get_batches (x_llm , batch_size = 1 , dtype = int )
587+ x_llm_b = get_batches (x_llm , batch_size = self . batch_size , dtype = int )
605588 y_llm = self .llm_inference_function (
606589 x_llm_b ,
607590 self .llm_attention_mask ,
@@ -615,6 +598,13 @@ def hybrid_prediction(
615598 #desc='Infering LoRA-tuned model',
616599 device = self .device ).detach ().cpu ().numpy ()
617600
601+
602+ y_dca , y_ridge , y_llm , y_llm_lora = (
603+ reduce_by_batch_modulo (y_dca , batch_size = self .batch_size ),
604+ reduce_by_batch_modulo (y_ridge , batch_size = self .batch_size ),
605+ reduce_by_batch_modulo (y_llm , batch_size = self .batch_size ),
606+ reduce_by_batch_modulo (y_llm_lora , batch_size = self .batch_size )
607+ )
618608 return self .beta1 * y_dca + self .beta2 * y_ridge + self .beta3 * y_llm + self .beta4 * y_llm_lora
619609
620610 def split_performance (
@@ -1131,12 +1121,16 @@ def generate_model_and_save_pkl(
11311121 save_model_to_dict_pickle (hybrid_model , model_name , beta_1 , beta_2 , test_spearman_r , reg )
11321122
11331123
1124+
11341125def performance_ls_ts (
11351126 ls_fasta : str | None ,
11361127 ts_fasta : str | None ,
11371128 threads : int ,
11381129 params_file : str ,
11391130 model_pickle_file : str | None = None ,
1131+ llm : str | None = None ,
1132+ wt_seq : str | None = None ,
1133+ pdb_file : str | None = None ,
11401134 substitution_sep : str = '/' ,
11411135 label = False
11421136):
@@ -1194,30 +1188,35 @@ def performance_ls_ts(
11941188 f"(after removing substitutions at gap positions)."
11951189 )
11961190
1191+ if llm == 'esm' :
1192+ llm_dict = esm_setup (train_sequences )
1193+ x_llm_test = esm_tokenize_sequences (
1194+ test_sequences , llm_dict ['llm_tokenizer' ], max_length = len (test_sequences [0 ])
1195+ )
1196+ elif llm == 'prosst' :
1197+ llm_dict = prosst_setup (wt_seq , pdb_file , sequences = train_sequences )
1198+ x_llm_test = prosst_tokenize_sequences (
1199+ test_sequences , llm_dict ['llm_tokenizer' ], max_length = len (test_sequences [0 ])
1200+ )
1201+ else :
1202+ llm_dict = None
1203+ x_llm_test = None
1204+ llm = ''
1205+
1206+
11971207 hybrid_model = DCALLMHybridModel (
1198- x_train = np .array (x_train ),
1208+ x_train_dca = np .array (x_train ),
11991209 y_train = np .array (y_train ),
1200- x_test = np .array (x_test ),
1201- y_test = np .array (y_test ),
1210+ llm_model_input = llm_dict ,
12021211 x_wt = x_wt
12031212 )
1204- model_name = f'HYBRID{ model_type .lower ()} '
1213+ model_name = f'HYBRID{ model_type .lower ()} { llm . lower () } '
12051214
1206- spearman_r , reg , beta_1 , beta_2 = hybrid_model .ls_ts_performance ()
1207- ys_pred = hybrid_model .hybrid_prediction (np .array (x_test ), reg , beta_1 , beta_2 )
1215+ y_test_pred = hybrid_model .hybrid_prediction (np .array (x_test ), x_llm_test )
12081216
1209- if reg is None :
1210- alpha_ = 'None'
1211- else :
1212- alpha_ = f'{ reg .alpha :.3f} '
1213- print (
1214- f'Individual model weights and regressor hyperparameters:\n '
1215- f'Hybrid model individual model contributions: Beta1 (DCA): '
1216- f'{ beta_1 :.3f} , Beta2 (ML): { beta_2 :.3f} (regressor: '
1217- f'Ridge(alpha={ alpha_ } ))\n Testing performance...'
1218- )
1217+ print (f'Hybrid performance: { spearmanr (y_test , y_test_pred )} ' )
12191218
1220- save_model_to_dict_pickle (hybrid_model , model_name , beta_1 , beta_2 , spearman_r , reg )
1219+ save_model_to_dict_pickle (hybrid_model , model_name )
12211220
12221221 elif ts_fasta is not None and model_pickle_file is not None and params_file is not None :
12231222 print (f'Taking model from saved model (Pickle file): { model_pickle_file } ...' )
@@ -1227,18 +1226,18 @@ def performance_ls_ts(
12271226 if model_type != 'Hybrid' : # same as below in next elif
12281227 x_test , test_variants , test_sequences , y_test , x_wt , * _ = plmc_or_gremlin_encoding (
12291228 test_variants , test_sequences , y_test , model_pickle_file , substitution_sep , threads , False )
1230- ys_pred = get_delta_e_statistical_model (x_test , x_wt )
1229+ y_test_pred = get_delta_e_statistical_model (x_test , x_wt )
12311230 else : # Hybrid model input requires params from plmc or GREMLIN model
1232- beta_1 , beta_2 , reg = model .beta_1 , model .beta_2 , model .regressor
1231+ # beta_1, beta_2, reg = model.beta_1, model.beta_2, model.regressor
12331232 x_test , test_variants , test_sequences , y_test , * _ = plmc_or_gremlin_encoding (
12341233 test_variants , test_sequences , y_test , params_file ,
12351234 substitution_sep , threads , False
12361235 )
1237- ys_pred = model .hybrid_prediction (x_test , reg , beta_1 , beta_2 )
1236+ y_test_pred = model .hybrid_prediction (x_test )
12381237
12391238 elif ts_fasta is not None and model_pickle_file is None : # no LS provided --> statistical modeling / no ML
12401239 print (f'No learning set provided, falling back to statistical DCA model: '
1241- f'no adjustments of individual hybrid model parameters (beta_1 and beta_2).' )
1240+ f'no adjustments of individual hybrid model parameters (beta_1 and beta_2).' )
12421241 test_sequences , test_variants , y_test = get_sequences_from_file (ts_fasta )
12431242 x_test , test_variants , test_sequences , y_test , x_wt , model , model_type = plmc_or_gremlin_encoding (
12441243 test_variants , test_sequences , y_test , params_file , substitution_sep , threads
@@ -1248,20 +1247,20 @@ def performance_ls_ts(
12481247 f"Remaining: { len (test_variants )} (after removing "
12491248 f"substitutions at gap positions)." )
12501249
1251- ys_pred = get_delta_e_statistical_model (x_test , x_wt )
1250+ y_test_pred = get_delta_e_statistical_model (x_test , x_wt )
12521251
1253- save_model_to_dict_pickle (model , model_type , None , None , spearmanr (y_test , ys_pred )[0 ], None )
1252+ save_model_to_dict_pickle (model , model_type , None , None , spearmanr (y_test , y_test_pred )[0 ], None )
12541253
12551254 model_type = f'{ model_type } _no_ML'
12561255
12571256 else :
12581257 raise SystemError ('No Test Set given for performance estimation.' )
12591258
1260- spearman_rho = spearmanr (y_test , ys_pred )
1259+ spearman_rho = spearmanr (y_test , y_test_pred )
12611260 print (f'Spearman Rho = { spearman_rho [0 ]:.3f} ' )
12621261
12631262 plot_y_true_vs_y_pred (
1264- np .array (y_test ), np .array (ys_pred ), np .array (test_variants ),
1263+ np .array (y_test ), np .array (y_test_pred ), np .array (test_variants ),
12651264 label = label , hybrid = True , name = model_type
12661265 )
12671266
0 commit comments