3737from sklearn .linear_model import Ridge
3838from sklearn .model_selection import GridSearchCV , train_test_split
3939from scipy .optimize import differential_evolution
40- from Bio import SeqIO , BiopythonParserWarning
41- warnings .filterwarnings (action = 'ignore' , category = BiopythonParserWarning )
4240
4341from pypef .utils .variant_data import (
4442 get_sequences_from_file , get_seqs_from_var_name ,
@@ -97,7 +95,7 @@ def __init__(
9795 self .llm_train_function = llm_model_input ['esm1v' ]['llm_train_function' ]
9896 self .llm_inference_function = llm_model_input ['esm1v' ]['llm_inference_function' ]
9997 self .llm_loss_function = llm_model_input ['esm1v' ]['llm_loss_function' ]
100- self .x_train_llm = llm_model_input ['esm1v' ]['x_llm_train ' ]
98+ self .x_train_llm = llm_model_input ['esm1v' ]['x_llm ' ]
10199 self .llm_attention_mask = llm_model_input ['esm1v' ]['llm_attention_mask' ]
102100 elif len (list (llm_model_input .keys ())) == 1 and list (llm_model_input .keys ())[0 ] == 'prosst' :
103101 self .llm_key = 'prosst'
@@ -107,7 +105,7 @@ def __init__(
107105 self .llm_train_function = llm_model_input ['prosst' ]['llm_train_function' ]
108106 self .llm_inference_function = llm_model_input ['prosst' ]['llm_inference_function' ]
109107 self .llm_loss_function = llm_model_input ['prosst' ]['llm_loss_function' ]
110- self .x_train_llm = llm_model_input ['prosst' ]['x_llm_train ' ]
108+ self .x_train_llm = llm_model_input ['prosst' ]['x_llm ' ]
111109 self .llm_attention_mask = llm_model_input ['prosst' ]['llm_attention_mask' ]
112110 self .input_ids = llm_model_input ['prosst' ]['input_ids' ]
113111 self .structure_input_ids = llm_model_input ['prosst' ]['structure_input_ids' ]
@@ -844,7 +842,8 @@ def plmc_or_gremlin_encoding(
844842 else :
845843 model , model_type = global_model , global_model_type
846844 else :
847- model , model_type = get_model_and_type (params_file , substitution_sep )
845+ model , model_type = get_model_and_type (
846+ params_file , substitution_sep )
848847 if model_type == 'PLMC' :
849848 xs , x_wt , variants , sequences , ys_true = plmc_encoding (
850849 model , variants , sequences , ys_true , threads , verbose
@@ -867,20 +866,25 @@ def plmc_or_gremlin_encoding(
867866 )
868867 else :
869868 raise SystemError (
870- f"Found a { model_type .lower ()} model as input. Please train a new "
871- f"hybrid model on the provided LS/TS datasets."
869+ f"Found a { model_type .lower ()} model as input. Please "
870+ f"train a new hybrid model on the provided LS/TS datasets."
872871 )
873872 assert len (xs ) == len (variants ) == len (sequences ) == len (ys_true )
874873 return xs , variants , sequences , ys_true , x_wt , model , model_type
875874
876875
877- def gremlin_encoding (gremlin : GREMLIN , variants , sequences , ys_true , shift_pos = 1 , substitution_sep = '/' ):
876+ def gremlin_encoding (gremlin : GREMLIN , variants , sequences , ys_true ,
877+ shift_pos = 1 , substitution_sep = '/' ):
878878 """
879879 Gets X and x_wt for DCA prediction: delta_Hamiltonian respectively
880880 delta_E = np.subtract(X, x_wt), with X = encoded sequences of variants.
881881 Also removes variants, sequences, and y_trues at MSA gap positions.
882882 """
883- variants , sequences , ys_true = np .atleast_1d (variants ), np .atleast_1d (sequences ), np .atleast_1d (ys_true )
883+ variants , sequences , ys_true = (
884+ np .atleast_1d (variants ),
885+ np .atleast_1d (sequences ),
886+ np .atleast_1d (ys_true )
887+ )
884888 variants , sequences , ys_true = remove_gap_pos (
885889 gremlin .gaps , variants , sequences , ys_true ,
886890 shift_pos = shift_pos , substitution_sep = substitution_sep
@@ -993,7 +997,8 @@ def generate_model_and_save_pkl(
993997 """
994998 wt_seq = get_wt_sequence (wt )
995999 variants_splitted = split_variants (variants , substitution_sep )
996- variants , ys_true , sequences = get_seqs_from_var_name (wt_seq , variants_splitted , ys_true )
1000+ variants , ys_true , sequences = get_seqs_from_var_name (
1001+ wt_seq , variants_splitted , ys_true )
9971002
9981003 xs , variants , sequences , ys_true , x_wt , _model , model_type = plmc_or_gremlin_encoding (
9991004 variants , sequences , ys_true , params_file , substitution_sep , threads )
@@ -1043,9 +1048,10 @@ def generate_model_and_save_pkl(
10431048
10441049
10451050def llm_embedder (llm_dict , seqs ):
1046- #try:
1047- np .shape (seqs )
1048- #except np.shape error:
1051+ try :
1052+ np .shape (seqs )
1053+ except ValueError :
1054+ raise SystemError ("Unequal input sequence length detected!" )
10491055 if list (llm_dict .keys ())[0 ] == 'esm1v' :
10501056 x_llm_seqs , _attention_mask = esm_tokenize_sequences (
10511057 seqs , tokenizer = llm_dict ['esm1v' ]['llm_tokenizer' ], max_length = len (seqs [0 ])
@@ -1069,7 +1075,8 @@ def performance_ls_ts(
10691075 pdb_file : str | None = None ,
10701076 wt_seq : str | None = None ,
10711077 substitution_sep : str = '/' ,
1072- label = False
1078+ label = False ,
1079+ device : str | None = None
10731080):
10741081 """
10751082 Description
@@ -1137,23 +1144,6 @@ def performance_ls_ts(
11371144 llm_dict = esm_setup (train_sequences )
11381145 x_llm_test = llm_embedder (llm_dict , test_sequences )
11391146 elif llm == 'prosst' :
1140- if pdb_file is None :
1141- raise SystemError (
1142- "Running ProSST requires a PDB file input "
1143- "for embedding sequences! Specify a PDB file "
1144- "with the --pdb flag."
1145- )
1146- if wt_seq is None :
1147- raise SystemError (
1148- "Running ProSST requires a wild-type sequence "
1149- "FASTA file input for embedding sequences! "
1150- "Specify a FASTA file with the --wt flag."
1151- )
1152- pdb_seq = str (list (SeqIO .parse (pdb_file , "pdb-atom" ))[0 ].seq )
1153- assert wt_seq == pdb_seq , (
1154- f"Wild-type sequence is not matching PDB-extracted sequence:"
1155- f"\n WT sequence:\n { wt_seq } \n PDB sequence:\n { pdb_seq } "
1156- )
11571147 llm_dict = prosst_setup (
11581148 wt_seq , pdb_file , sequences = train_sequences )
11591149 x_llm_test = llm_embedder (llm_dict , test_sequences )
@@ -1173,6 +1163,7 @@ def performance_ls_ts(
11731163 save_model_to_dict_pickle (hybrid_model , model_name )
11741164
11751165 elif ts_fasta is not None and model_pickle_file is not None and params_file is not None :
1166+ # # no LS provided --> statistical modeling / no ML
11761167 print (f'Taking model from saved model (Pickle file): { model_pickle_file } ...' )
11771168 model , model_type = get_model_and_type (model_pickle_file )
11781169 if model_type != 'Hybrid' : # same as below in next elif
@@ -1193,33 +1184,60 @@ def performance_ls_ts(
11931184 substitution_sep , threads , False
11941185 )
11951186 if model .llm_model_input is not None :
1196- if list (model .llm_model_input .keys ())[0 ] == 'esm1v' :
1197- pass
1187+ print (f"Found hybrid model with LLM { list (model .llm_model_input .keys ())[0 ]} ..." )
1188+ x_llm_test = llm_embedder (llm_dict , test_sequences )
1189+ model .hybrid_prediction (x_test , x_llm_test )
11981190 else :
11991191 y_test_pred = model .hybrid_prediction (x_test )
12001192
12011193 elif ts_fasta is not None and model_pickle_file is None : # no LS provided --> statistical modeling / no ML
1202- print (f' No learning set provided, falling back to statistical DCA model: '
1203- f' no adjustments of individual hybrid model parameters (beta_1 and beta_2).' )
1194+ print (f" No learning set provided, falling back to statistical DCA model: "
1195+ f" no adjustments of individual hybrid model parameters (\" beta's \" )." )
12041196 test_sequences , test_variants , y_test = get_sequences_from_file (ts_fasta )
1205- x_test , test_variants , test_sequences , y_test , x_wt , model , model_type = plmc_or_gremlin_encoding (
1206- test_variants , test_sequences , y_test , params_file , substitution_sep , threads
1197+ (
1198+ x_test , test_variants , test_sequences ,
1199+ y_test , x_wt , model , model_type
1200+ ) = plmc_or_gremlin_encoding (
1201+ test_variants , test_sequences , y_test ,
1202+ params_file , substitution_sep , threads
12071203 )
1208-
12091204 print (f"Initial test set variants: { len (test_sequences )} . "
12101205 f"Remaining: { len (test_variants )} (after removing "
12111206 f"substitutions at gap positions)." )
1212-
12131207 y_test_pred = get_delta_e_statistical_model (x_test , x_wt )
1214- save_model_to_dict_pickle (model , model_type , None , None , spearmanr (y_test , y_test_pred )[0 ], None )
1208+ if llm == 'esm' :
1209+ llm_dict = esm_setup (test_sequences )
1210+ x_llm_test = llm_embedder (llm_dict , test_sequences )
1211+ y_test_pred_llm = llm_dict ['esm1v' ]['llm_inference_function' ](
1212+ xs = get_batches (x_llm_test , batch_size = 1 , dtype = int ),
1213+ attention_mask = llm_dict ['esm1v' ]['llm_attention_mask' ],
1214+ model = llm_dict ['esm1v' ]['llm_base_model' ],
1215+ device = device
1216+ ).cpu ()
1217+ plot_y_true_vs_y_pred (
1218+ np .array (y_test ), np .array (y_test_pred_llm ), np .array (test_variants ),
1219+ label = label , hybrid = True , name = f'ESM1v_no_ML'
1220+ )
1221+ elif llm == 'prosst' :
1222+ llm_dict = prosst_setup (
1223+ wt_seq , pdb_file , sequences = test_sequences )
1224+ x_llm_test = llm_embedder (llm_dict , test_sequences )
1225+ y_test_pred_llm = llm_dict ['prosst' ]['llm_inference_function' ](
1226+ xs = x_llm_test ,
1227+ model = llm_dict ['prosst' ]['llm_base_model' ],
1228+ input_ids = llm_dict ['prosst' ]['input_ids' ],
1229+ attention_mask = llm_dict ['prosst' ]['llm_attention_mask' ],
1230+ structure_input_ids = llm_dict ['prosst' ]['structure_input_ids' ],
1231+ device = device
1232+ ).cpu ()
1233+ plot_y_true_vs_y_pred (
1234+ np .array (y_test ), np .array (y_test_pred_llm ), np .array (test_variants ),
1235+ label = label , hybrid = True , name = f'ProSST_no_ML'
1236+ )
1237+ save_model_to_dict_pickle (model , model_type )
12151238 model_type = f'{ model_type } _no_ML'
1216-
12171239 else :
1218- raise SystemError ('No Test Set given for performance estimation.' )
1219-
1220- spearman_rho = spearmanr (y_test , y_test_pred )
1221- print (f'Spearman Rho = { spearman_rho [0 ]:.3f} ' )
1222-
1240+ raise SystemError ('No test set given for performance estimation.' )
12231241 plot_y_true_vs_y_pred (
12241242 np .array (y_test ), np .array (y_test_pred ), np .array (test_variants ),
12251243 label = label , hybrid = True , name = model_type
0 commit comments