@@ -117,6 +117,7 @@ def __init__(
117117 else :
118118 print ("No LLM inputs were defined for hybrid modelling. "
119119 "Using only DCA for hybrid modeling..." )
120+ self .llm_model_input = llm_model_input # = None
120121 self .llm_attention_mask = None
121122 if parameter_range is None :
122123 parameter_range = [(0 , 1 ), (0 , 1 )]
@@ -545,7 +546,7 @@ def train_and_optimize(self) -> tuple:
545546 def hybrid_prediction (
546547 self ,
547548 x_dca : np .ndarray ,
548- x_llm : None | np .ndarray
549+ x_llm : None | np .ndarray = None
549550 ) -> np .ndarray :
550551 """
551552 Use the regressor 'reg' and the parameters 'beta_1'
@@ -735,7 +736,7 @@ def get_model_path(model: str):
735736 else :
736737 raise SystemError (
737738 "Did not find specified model file in current "
738- "working directory or /Pickles subdirectory. "
739+ "working directory or /Pickles subdirectory. "
739740 "Make sure to train/save a model first (e.g., "
740741 "for saving a GREMLIN model, type \" pypef "
741742 "param_inference --msa TARGET_MSA.a2m\" or, for"
@@ -798,6 +799,7 @@ def save_model_to_dict_pickle(
798799 model_type = 'MODEL'
799800
800801 pkl_path = os .path .abspath (f'Pickles/{ model_type } ' )
802+ # TODO: When saving LLM models, try storing model.state_dict() instead of the model object
801803 pickle .dump (
802804 {
803805 'model' : model ,
@@ -1326,7 +1328,7 @@ def predict_ps( # also predicting "pmult" dict directories
13261328 all_y_v_pred = []
13271329 files = [f for f in listdir (path ) if isfile (join (path , f )) if f .endswith ('.fasta' )]
13281330 for i , file in enumerate (files ): # collect and predict for each file in the directory
1329- print (f'Encoding files ({ i + 1 } /{ len (files )} ) for prediction...\n ' )
1331+ print (f'Encoding files ({ i + 1 } /{ len (files )} ) for prediction...' )
13301332 file_path = os .path .join (path , file )
13311333 sequences , variants , _ = get_sequences_from_file (file_path )
13321334 if model_type != 'Hybrid' :
@@ -1359,7 +1361,7 @@ def predict_ps( # also predicting "pmult" dict directories
13591361 else : # check the next task to do, e.g., predicting triple-substituted variants (trecomb)
13601362 continue
13611363
1362- elif prediction_set is not None :
1364+ elif prediction_set is not None : # Predicting single FASTA file sequences
13631365 sequences , variants , _ = get_sequences_from_file (prediction_set )
13641366 # NaNs are already being removed by the called function
13651367 if model_type != 'Hybrid' : # statistical DCA model
0 commit comments