@@ -1417,7 +1417,6 @@ def predict_ps( # also predicting "pmult" dict directories
14171417
14181418 elif prediction_set is not None : # Predicting single FASTA file sequences
14191419 sequences , variants , _ = get_sequences_from_file (prediction_set )
1420- print (len (sequences ), len (variants ))
14211420 # NaNs are already being removed by the called function
14221421 if model_type != 'Hybrid' : # statistical DCA model
14231422 xs , variants , _ , _ , x_wt , * _ = plmc_or_gremlin_encoding (
@@ -1426,12 +1425,10 @@ def predict_ps( # also predicting "pmult" dict directories
14261425 )
14271426 ys_pred = get_delta_e_statistical_model (xs , x_wt )
14281427 else : # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1429- print (len (variants ))
14301428 xs , variants , sequences , * _ = plmc_or_gremlin_encoding (
14311429 variants , sequences , None , params_file ,
14321430 threads = threads , verbose = True , substitution_sep = separator
14331431 )
1434- print ('xs len' , len (xs ), len (variants ))
14351432 if model .llm_key is None :
14361433 ys_pred = model .hybrid_prediction (xs )
14371434 else :
@@ -1448,11 +1445,15 @@ def predict_ps( # also predicting "pmult" dict directories
14481445 )
14491446
14501447
1448+ global_hybrid_model = None
1449+ global_hybrid_model_type = None
1450+
1451+
14511452def predict_directed_evolution (
14521453 encoder : str ,
14531454 variant : str ,
14541455 variant_sequence : str ,
1455- hybrid_model_data_pkl : str
1456+ hybrid_model_data_pkl : None | str
14561457) -> Union [str , list ]:
14571458 """
14581459 Perform directed in silico evolution and predict the fitness of a
@@ -1462,8 +1463,14 @@ def predict_directed_evolution(
14621463 cannot be encoded (based on the PLMC params file), returns 'skip'. Else,
14631464 returning the predicted fitness value and the variant name.
14641465 """
1466+ global global_hybrid_model , global_hybrid_model_type
14651467 if hybrid_model_data_pkl is not None :
1466- model , model_type = get_model_and_type (hybrid_model_data_pkl )
1468+ if global_hybrid_model is None :
1469+ global_hybrid_model , global_hybrid_model_type = get_model_and_type (
1470+ hybrid_model_data_pkl )
1471+ model , model_type = global_hybrid_model , global_hybrid_model_type
1472+ else :
1473+ model , model_type = global_hybrid_model , global_hybrid_model_type
14671474 else :
14681475 model_type = 'StatisticalModel' # any name != 'Hybrid'
14691476
@@ -1475,7 +1482,7 @@ def predict_directed_evolution(
14751482 return 'skip'
14761483 y_pred = get_delta_e_statistical_model (xs , x_wt )
14771484 else : # model_type == 'Hybrid': Hybrid model input requires params
1478- #from PLMC or GREMLIN model plus optional LLM input
1485+ # from PLMC or GREMLIN model plus optional LLM input
14791486 xs , variant , variant_sequence , * _ = plmc_or_gremlin_encoding (
14801487 variant , variant_sequence , None , encoder ,
14811488 verbose = False , use_global_model = True
@@ -1485,9 +1492,13 @@ def predict_directed_evolution(
14851492 if model .llm_model_input is None :
14861493 x_llm = None
14871494 else :
1488- x_llm = llm_embedder (model .llm_model_input , variant_sequence , verbose = False )
1495+ x_llm = llm_embedder (model .llm_model_input ,
1496+ variant_sequence , verbose = False )
14891497 try :
1490- y_pred = model .hybrid_prediction (np .atleast_2d (xs ), np .atleast_2d (x_llm ), verbose = False )[0 ]
1498+ y_pred = model .hybrid_prediction (
1499+ np .atleast_2d (xs ),
1500+ np .atleast_2d (x_llm ), verbose = False
1501+ )[0 ]
14911502 except ValueError as e :
14921503 raise e # TODO: Check sequences / mutations
14931504 # raise SystemError(
0 commit comments