@@ -79,6 +79,7 @@ def __init__(
7979 alphas : np .ndarray | None = None ,
8080 parameter_range : list [tuple ] | None = None ,
8181 batch_size : int | None = None ,
82+ llm_train : bool = True ,
8283 device : str | None = None ,
8384 seed : int | None = None
8485 ):
@@ -135,6 +136,7 @@ def __init__(
135136 if batch_size is None :
136137 batch_size = 5
137138 self .batch_size = batch_size
139+ self .llm_train = llm_train
138140 (
139141 self .ridge_opt ,
140142 self .beta1 ,
@@ -408,7 +410,7 @@ def train_llm(self):
408410 input_ids = self .input_ids ,
409411 attention_mask = self .llm_attention_mask ,
410412 structure_input_ids = self .structure_input_ids ,
411- train = True ,
413+ train = False ,
412414 device = self .device
413415 )
414416 y_llm_ttrain = self .llm_inference_function (
@@ -417,7 +419,7 @@ def train_llm(self):
417419 input_ids = self .input_ids ,
418420 attention_mask = self .llm_attention_mask ,
419421 structure_input_ids = self .structure_input_ids ,
420- train = True ,
422+ train = False ,
421423 device = self .device
422424 )
423425 elif self .llm_key == 'esm1v' :
@@ -1206,7 +1208,11 @@ def performance_ls_ts(
12061208 print (f'Hybrid performance: { spearmanr (y_test , y_test_pred )} ' )
12071209 save_model_to_dict_pickle (hybrid_model , model_name )
12081210
1209- elif ts_fasta is not None and model_pickle_file is not None and params_file is not None :
1211+ elif (
1212+ ts_fasta is not None and
1213+ model_pickle_file is not None
1214+ and params_file is not None
1215+ ):
12101216 # # no LS provided --> statistical modeling / no ML
12111217 print (f'Taking model from saved model (Pickle file): { model_pickle_file } ...' )
12121218 model , model_type = get_model_and_type (model_pickle_file )
@@ -1233,8 +1239,9 @@ def performance_ls_ts(
12331239 model .hybrid_prediction (x_test , x_llm_test )
12341240 else :
12351241 y_test_pred = model .hybrid_prediction (x_test )
1236-
1237- elif ts_fasta is not None and model_pickle_file is None : # no LS provided --> statistical modeling / no ML
1242+
1243+ # no LS provided --> statistical modeling / no ML
1244+ elif ts_fasta is not None and model_pickle_file is None :
12381245 print (f"No learning set provided, falling back to statistical DCA model: "
12391246 f"no adjustments of individual hybrid model parameters (\" beta's\" )." )
12401247 test_sequences , test_variants , y_test = get_sequences_from_file (ts_fasta )
@@ -1354,7 +1361,8 @@ def predict_ps( # also predicting "pmult" dict directories
13541361 model , model_type = get_model_and_type (model_pickle_file )
13551362
13561363 if model_type == 'PLMC' or model_type == 'GREMLIN' :
1357- print (f'Found { model_type } model file. No hybrid model provided - falling back to a statistical DCA model...' )
1364+ print (f'Found { model_type } model file. No hybrid model provided - '
1365+ f'falling back to a statistical DCA model...' )
13581366
13591367 pmult = [
13601368 'Recomb_Double_Split' , 'Recomb_Triple_Split' , 'Recomb_Quadruple_Split' ,
@@ -1377,14 +1385,14 @@ def predict_ps( # also predicting "pmult" dict directories
13771385 substitution_sep = separator )
13781386 ys_pred = get_delta_e_statistical_model (x_test , x_wt )
13791387 else : # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1380- x_test , _test_variants , * _ = plmc_or_gremlin_encoding (
1388+ x_test , _test_variants , test_sequences , * _ = plmc_or_gremlin_encoding (
13811389 variants , sequences , None , params_file ,
13821390 threads = threads , verbose = False , substitution_sep = separator
13831391 )
13841392 if model .llm_key is None :
13851393 ys_pred = model .hybrid_prediction (x_test )
13861394 else :
1387- sequences = [str (seq ) for seq in sequences ]
1395+ sequences = [str (seq ) for seq in test_sequences ]
13881396 x_llm_test = llm_embedder (model .llm_model_input , sequences )
13891397 ys_pred = model .hybrid_prediction (np .asarray (x_test ), np .asarray (x_llm_test ))
13901398 for k , y in enumerate (ys_pred ):
@@ -1404,6 +1412,7 @@ def predict_ps( # also predicting "pmult" dict directories
14041412
14051413 elif prediction_set is not None : # Predicting single FASTA file sequences
14061414 sequences , variants , _ = get_sequences_from_file (prediction_set )
1415+ print (len (sequences ), len (variants ))
14071416 # NaNs are already being removed by the called function
14081417 if model_type != 'Hybrid' : # statistical DCA model
14091418 xs , variants , _ , _ , x_wt , * _ = plmc_or_gremlin_encoding (
@@ -1412,13 +1421,16 @@ def predict_ps( # also predicting "pmult" dict directories
14121421 )
14131422 ys_pred = get_delta_e_statistical_model (xs , x_wt )
14141423 else : # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1415- xs , variants , * _ = plmc_or_gremlin_encoding (
1424+ print (len (variants ))
1425+ xs , variants , sequences , * _ = plmc_or_gremlin_encoding (
14161426 variants , sequences , None , params_file ,
14171427 threads = threads , verbose = True , substitution_sep = separator
14181428 )
1429+ print ('xs len' , len (xs ), len (variants ))
14191430 if model .llm_key is None :
14201431 ys_pred = model .hybrid_prediction (xs )
14211432 else :
1433+ sequences = [str (seq ) for seq in sequences ]
14221434 xs_llm = llm_embedder (model .llm_model_input , sequences )
14231435 ys_pred = model .hybrid_prediction (np .asarray (xs ), np .asarray (xs_llm ))
14241436 assert len (xs ) == len (variants ) == len (xs_llm ) == len (ys_pred )
@@ -1434,7 +1446,7 @@ def predict_ps( # also predicting "pmult" dict directories
14341446def predict_directed_evolution (
14351447 encoder : str ,
14361448 variant : str ,
1437- sequence : str ,
1449+ variant_sequence : str ,
14381450 hybrid_model_data_pkl : str
14391451) -> Union [str , list ]:
14401452 """
@@ -1452,27 +1464,36 @@ def predict_directed_evolution(
14521464
14531465 if model_type != 'Hybrid' : # statistical DCA model
14541466 xs , variant , _ , _ , x_wt , * _ = plmc_or_gremlin_encoding (
1455- variant , sequence , None , encoder , verbose = False , use_global_model = True )
1467+ variant , variant_sequence , None , encoder ,
1468+ verbose = False , use_global_model = True )
14561469 if not list (xs ):
14571470 return 'skip'
14581471 y_pred = get_delta_e_statistical_model (xs , x_wt )
1459- else : # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model plus optional LLM input
1460- xs , variant , * _ = plmc_or_gremlin_encoding (
1461- variant , sequence , None , encoder , verbose = False , use_global_model = True
1472+ else : # model_type == 'Hybrid': Hybrid model input requires params
1473+ #from PLMC or GREMLIN model plus optional LLM input
1474+ print (variant , variant_sequence )
1475+ xs , variant , variant_sequence , * _ = plmc_or_gremlin_encoding (
1476+ variant , variant_sequence , None , encoder ,
1477+ verbose = False , use_global_model = True
14621478 )
1479+ print (variant_sequence )
14631480 if not list (xs ):
14641481 return 'skip'
14651482 if model .llm_model_input is None :
14661483 x_llm = None
14671484 else :
1468- x_llm = llm_embedder (model .llm_model_input , sequence )
1485+ x_llm = llm_embedder (model .llm_model_input , variant_sequence )
14691486 try :
1487+ print (np .shape (xs ), np .shape (x_llm ), np .atleast_2d (x_llm ))
1488+ #exit()
14701489 y_pred = model .hybrid_prediction (np .atleast_2d (xs ), np .atleast_2d (x_llm ))[0 ]
1471- except ValueError :
1472- raise SystemError (
1473- "Probably a different model was used for encoding than for modeling; "
1474- "e.g. using a HYBRIDgremlin model in combination with parameters taken from a PLMC file."
1475- )
1490+ except ValueError as e :
1491+ raise e # TODO: Check sequences / mutations
1492+ # raise SystemError(
1493+ # "Probably a different model was used for encoding than "
1494+ # "for modeling; e.g. using a HYBRIDgremlin model in "
1495+ # "combination with parameters taken from a PLMC file."
1496+ # )
14761497 y_pred = float (y_pred )
14771498
14781499 return [(y_pred , variant [0 ][1 :])]
0 commit comments