@@ -609,69 +609,6 @@ def hybrid_prediction(
609609
610610 return self .beta1 * y_dca + self .beta2 * y_ridge + self .beta3 * y_llm + self .beta4 * y_llm_lora
611611
def split_performance(
        self,
        train_size: float = 0.8,
        n_runs: int = 10,
        seed: int = 42,
        save_model: bool = False
) -> dict:
    """
    Estimate hybrid-model performance over repeated random
    train/test splits of the stored dataset (self.X, self.y).

    Parameters
    ----------
    train_size : int or float (default=0.8)
        Number of samples in the training dataset
        or fraction of full dataset used for training.
    n_runs : int (default=10)
        Number of different splits to perform.
    seed : int (default=42)
        Seed for random generator.
    save_model : bool (default=False)
        If True, the fitted regressor is serialized with pickle.
        NOTE(review): ``pickle.dumps`` returns bytes that are
        discarded here, so nothing is persisted to disk — confirm
        the intended save target.

    Returns
    -------
    results : dict
        One entry per split, keyed ``"<n_train>_<run>"``, containing
        the split metadata, Spearman rho on the test split, the model
        weights beta_1/beta_2 and the Ridge alpha (NaN if beta_2 == 0).
    """
    results = {}
    # Global NumPy seeding so the per-run random states are reproducible.
    np.random.seed(seed)

    for run_no, rnd_state in enumerate(np.random.randint(100, size=n_runs)):
        x_train, x_test, y_train, y_test = train_test_split(
            self.X, self.y, train_size=train_size, random_state=rnd_state)
        beta_1, beta_2, reg = self.settings(x_train, y_train)
        if beta_2 == 0.0:
            # Pure-DCA model: no regressor contribution, hence no alpha.
            alpha = np.nan
        else:
            if save_model:
                pickle.dumps(reg)
            alpha = reg.alpha
        results[f'{len(y_train)}_{run_no}'] = {
            'no_run': run_no,
            'n_y_train': len(y_train),
            'n_y_test': len(y_test),
            'rnd_state': rnd_state,
            'spearman_rho': self.spearmanr(
                y_test,
                self.hybrid_prediction(x_test, reg, beta_1, beta_2)
            ),
            'beta_1': beta_1,
            'beta_2': beta_2,
            'alpha': alpha,
        }

    return results
674-
675612 def ls_ts_performance (self ):
676613 beta_1 , beta_2 , reg = self .settings (
677614 x_train = self .x_train ,
@@ -744,49 +681,13 @@ def train_and_test(
744681 test_spearman_r = None
745682 return beta_1 , beta_2 , reg , self ._spearmanr_dca , test_spearman_r
746683
def get_train_sizes(self) -> np.ndarray:
    """
    Build the list of training-set sizes used for low-N evaluation.

    Candidate sizes grow in increasingly coarse steps and are capped
    at 80% of the number of stored fitness values (self.y).

    Returns
    -------
    Sorted numpy array of train sizes up to and including
    int(0.8 * N_variants).
    """
    cap = int(len(self.y) * 0.8)

    candidates = np.sort(np.concatenate([
        np.arange(15, 50, 5),
        np.arange(50, 100, 10),
        np.arange(100, 150, 20),
        [160, 200, 250, 300, cap],
        np.arange(400, 1100, 100),
    ]))

    # First position where a candidate reaches the 80% cap; keep
    # everything up to and including that candidate.
    cutoff = int(np.searchsorted(candidates, cap, side='left')) + 1
    return candidates[:cutoff]
765-
def run(
        self,
        train_sizes: list = None,
        n_runs: int = 10
) -> dict:
    """
    Run split_performance for each training-set size.

    Parameters
    ----------
    train_sizes : list (default=None)
        Training-set sizes to evaluate. If None, falls back to
        self.get_train_sizes() (previously this crashed with a
        TypeError when the default was used).
    n_runs : int (default=10)
        Number of random splits per training-set size.

    Returns
    -------
    data : dict
        Performances of the split with size of the
        training set = train_size and size of the
        test set = N_variants - train_size.
    """
    if train_sizes is None:
        # Documented default previously passed None to enumerate();
        # derive the sizes from the dataset instead.
        train_sizes = self.get_train_sizes()
    data = {}
    for t, train_size in enumerate(train_sizes):
        print(f'{t + 1}/{len(train_sizes)}:{train_size}')
        data.update(self.split_performance(train_size=train_size, n_runs=n_runs))
    return data
785684
786685
787- """
788- Below: Some helper functions that call or are dependent on the DCALLMHybridModel class.
789- """
686+ """
687+ ###########################################################################################
688+ # Below: Some helper functions that call or are dependent on the DCALLMHybridModel class. #
689+ ###########################################################################################
690+ """
790691
791692
792693def check_model_type (model : dict | DCALLMHybridModel | PLMC | GREMLIN ):
@@ -940,11 +841,11 @@ def plmc_or_gremlin_encoding(
940841 elif model_type == 'GREMLIN' :
941842 if verbose :
942843 print (f"Following positions are frequent gap positions in the MSA "
943- f"and cannot be considered for effective modeling, i.e., "
944- f"substitutions at these positions are removed as these would be "
945- f"predicted with wild-type fitness:\n { [int (gap ) + 1 for gap in model .gaps ]} .\n "
946- f"Effective positions (N={ len (model .v_idx )} ) are:\n "
947- f"{ [int (v_pos ) + 1 for v_pos in model .v_idx ]} " )
844+ f"and cannot be considered for effective modeling, i.e., "
845+ f"substitutions at these positions are removed as these would be "
846+ f"predicted with wild-type fitness:\n { [int (gap ) + 1 for gap in model .gaps ]} .\n "
847+ f"Effective positions (N={ len (model .v_idx )} ) are:\n "
848+ f"{ [int (v_pos ) + 1 for v_pos in model .v_idx ]} " )
948849 xs , x_wt , variants , sequences , ys_true = gremlin_encoding (
949850 model , variants , sequences , ys_true ,
950851 shift_pos = 1 , substitution_sep = substitution_sep
@@ -987,7 +888,7 @@ def plmc_encoding(plmc: PLMC, variants, sequences, ys_true, threads=1, verbose=F
987888 wt_name = target_seq [0 ] + str (index [0 ]) + target_seq [0 ]
988889 if verbose :
989890 print (f"Using to-self-substitution '{ wt_name } ' as wild type reference. "
990- f"Encoding variant sequences. This might take some time..." )
891+ f"Encoding variant sequences. This might take some time..." )
991892 x_wt = get_encoded_sequence (wt_name , plmc )
992893 if threads > 1 :
993894 # Hyperthreading, NaNs are already being removed by the called function
@@ -1123,6 +1024,22 @@ def generate_model_and_save_pkl(
11231024 save_model_to_dict_pickle (hybrid_model , model_name , beta_1 , beta_2 , test_spearman_r , reg )
11241025
11251026
def llm_embedder(llm_dict, seqs):
    """
    Tokenize sequences with the tokenizer of the LLM selected by the
    first key of ``llm_dict`` ('esm1v' or 'prosst').

    Parameters
    ----------
    llm_dict : dict
        LLM setup dictionary; its first key names the model and
        ``llm_dict[key]['llm_tokenizer']`` provides the tokenizer.
    seqs : list
        Sequences to tokenize. len(seqs[0]) is used as max_length
        (presumably all sequences share one length — TODO confirm).

    Returns
    -------
    Tokenized sequences as returned by the respective
    *_tokenize_sequences helper.

    Raises
    ------
    SystemError
        If the first key of ``llm_dict`` names an unknown LLM.
    """
    # Look the selector key up once (was repeated list(llm_dict.keys())[0]
    # calls); the dead commented-out try/except and the no-op
    # np.shape(seqs) call were removed.
    llm_name = list(llm_dict)[0]
    if llm_name == 'esm1v':
        return esm_tokenize_sequences(
            seqs, llm_dict['esm1v']['llm_tokenizer'], max_length=len(seqs[0])
        )
    if llm_name == 'prosst':
        return prosst_tokenize_sequences(
            seqs, llm_dict['prosst']['llm_tokenizer'], max_length=len(seqs[0])
        )
    raise SystemError(f"Unknown LLM dictionary input:\n{list(llm_dict.keys())[0]}")
1042+
11261043
11271044def performance_ls_ts (
11281045 ls_fasta : str | None ,
@@ -1188,25 +1105,18 @@ def performance_ls_ts(
11881105 f"substitutions at gap positions).\n Initial test set "
11891106 f"variants: { len (test_sequences )} . Remaining: { len (test_variants )} "
11901107 f"(after removing substitutions at gap positions)."
1191- )
1192- print ('LLM:' , llm )
1108+ )
11931109 if llm == 'esm' :
11941110 llm_dict = esm_setup (train_sequences )
1195- print ('XX' , llm_dict )
1196- x_llm_test = esm_tokenize_sequences (
1197- test_sequences , llm_dict ['esm1v' ]['llm_tokenizer' ], max_length = len (test_sequences [0 ])
1198- )
1111+ x_llm_test = llm_embedder (llm_dict , test_sequences )
11991112 elif llm == 'prosst' :
12001113 llm_dict = prosst_setup (wt_seq , pdb_file , sequences = train_sequences )
1201- x_llm_test = prosst_tokenize_sequences (
1202- test_sequences , llm_dict ['prosst' ]['llm_tokenizer' ], max_length = len (test_sequences [0 ])
1203- )
1114+ x_llm_test = llm_embedder (llm_dict , test_sequences )
12041115 else :
12051116 llm_dict = None
12061117 x_llm_test = None
12071118 llm = ''
12081119
1209-
12101120 hybrid_model = DCALLMHybridModel (
12111121 x_train_dca = np .array (x_train ),
12121122 y_train = np .array (y_train ),
@@ -1245,13 +1155,11 @@ def performance_ls_ts(
12451155 )
12461156
12471157 print (f"Initial test set variants: { len (test_sequences )} . "
1248- f"Remaining: { len (test_variants )} (after removing "
1249- f"substitutions at gap positions)." )
1158+ f"Remaining: { len (test_variants )} (after removing "
1159+ f"substitutions at gap positions)." )
12501160
12511161 y_test_pred = get_delta_e_statistical_model (x_test , x_wt )
1252-
12531162 save_model_to_dict_pickle (model , model_type , None , None , spearmanr (y_test , y_test_pred )[0 ], None )
1254-
12551163 model_type = f'{ model_type } _no_ML'
12561164
12571165 else :
@@ -1332,18 +1240,9 @@ def predict_ps( # also predicting "pmult" dict directories
13321240 model , model_type = get_model_and_type (model_pickle_file )
13331241
13341242 if model_type == 'PLMC' or model_type == 'GREMLIN' :
1335- print (f'No hybrid model provided - falling back to a statistical DCA model.' )
1243+ print (f'Found { model_type } model file. No hybrid model provided - falling back to a statistical DCA model.. .' )
13361244 elif model_type == 'Hybrid' :
1337- beta_1 , beta_2 , reg = model .beta_1 , model .beta_2 , model .regressor
1338- if reg is None :
1339- alpha_ = 'None'
1340- else :
1341- alpha_ = f'{ reg .alpha :.3f} '
1342- print (
1343- f'Individual model weights and regressor hyperparameters:\n '
1344- f'Hybrid model individual model contributions: Beta1 (DCA): { beta_1 :.3f} , '
1345- f'Beta2 (ML): { beta_2 :.3f} (regressor: Ridge(alpha={ alpha_ } )).'
1346- )
1245+ print (f'Found hybrid model...' )
13471246
13481247 pmult = [
13491248 'Recomb_Double_Split' , 'Recomb_Triple_Split' , 'Recomb_Quadruple_Split' ,
@@ -1365,13 +1264,16 @@ def predict_ps( # also predicting "pmult" dict directories
13651264 variants , sequences , None , model , threads = threads , verbose = False ,
13661265 substitution_sep = separator )
13671266 ys_pred = get_delta_e_statistical_model (x_test , x_wt )
1368- else : # Hybrid model input requires params from plmc or GREMLIN model
1369- ##encoding_model, encoding_model_type = get_model_and_type(params_file)
1267+ else : # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
13701268 x_test , _test_variants , * _ = plmc_or_gremlin_encoding (
13711269 variants , sequences , None , params_file ,
13721270 threads = threads , verbose = False , substitution_sep = separator
13731271 )
1374- ys_pred = model .hybrid_prediction (x_test , reg , beta_1 , beta_2 )
1272+ if model .llm_model_input is None :
1273+ ys_pred = model .hybrid_prediction (x_test )
1274+ else :
1275+ x_llm_test = llm_embedder (model .llm_model_input , sequences )
1276+ ys_pred = model .hybrid_prediction (np .asarray (x_test ), np .asarray (x_llm_test ))
13751277 for k , y in enumerate (ys_pred ):
13761278 all_y_v_pred .append ((ys_pred [k ], variants [k ]))
13771279 if negative : # sort by fitness value
@@ -1395,13 +1297,17 @@ def predict_ps( # also predicting "pmult" dict directories
13951297 variants , sequences , None , params_file ,
13961298 threads = threads , verbose = False , substitution_sep = separator )
13971299 ys_pred = get_delta_e_statistical_model (xs , x_wt )
1398- else : # Hybrid model input requires params from plmc or GREMLIN model
1300+ else : # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
13991301 xs , variants , * _ = plmc_or_gremlin_encoding (
14001302 variants , sequences , None , params_file ,
14011303 threads = threads , verbose = True , substitution_sep = separator
14021304 )
1403- ys_pred = model .hybrid_prediction (xs , reg , beta_1 , beta_2 )
1404- assert len (xs ) == len (variants )
1305+ if model .llm_model_input is None :
1306+ ys_pred = model .hybrid_prediction (xs )
1307+ else :
1308+ xs_llm = llm_embedder (model .llm_model_input , sequences )
1309+ ys_pred = model .hybrid_prediction (np .asarray (xs ), np .asarray (xs_llm ))
1310+ assert len (xs ) == len (variants ) == len (xs_llm ) == len (ys_pred )
14051311 y_v_pred = zip (ys_pred , variants )
14061312 y_v_pred = sorted (y_v_pred , key = lambda x : x [0 ], reverse = True )
14071313 predictions_out (
@@ -1436,14 +1342,18 @@ def predict_directed_evolution(
14361342 if not list (xs ):
14371343 return 'skip'
14381344 y_pred = get_delta_e_statistical_model (xs , x_wt )
1439- else : # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model
1345+ else : # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model plus optional LLM input
14401346 xs , variant , * _ = plmc_or_gremlin_encoding (
14411347 variant , sequence , None , encoder , verbose = False , use_global_model = True
14421348 )
14431349 if not list (xs ):
14441350 return 'skip'
1351+ if model .llm_model_input is None :
1352+ x_llm = None
1353+ else :
1354+ x_llm = llm_embedder (model .llm_model_input , sequence )
14451355 try :
1446- y_pred = model .hybrid_prediction (np .atleast_2d (xs ), model . regressor , model . beta_1 , model . beta_2 )[0 ]
1356+ y_pred = model .hybrid_prediction (np .atleast_2d (xs ), np . atleast_2d ( x_llm ) )[0 ]
14471357 except ValueError :
14481358 raise SystemError (
14491359 "Probably a different model was used for encoding than for modeling; "
0 commit comments