@@ -160,10 +160,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
160160 hybrid_perfs = []
161161 ns_y_test = [len (variants )]
162162 for i_t , train_size in enumerate ([100 , 200 , 1000 ]):
163- prosst_lora_model = copy .deepcopy (prosst_lora_model )
164- prosst_optimizer = torch .optim .Adam (prosst_lora_model .parameters (), lr = 0.0001 )
165- esm_lora_model = copy .deepcopy (esm_lora_model )
166- esm_optimizer = torch .optim .Adam (esm_lora_model .parameters (), lr = 0.0001 )
163+ prosst_lora_model_2 = copy .deepcopy (prosst_lora_model )
164+ prosst_optimizer = torch .optim .Adam (prosst_lora_model_2 .parameters (), lr = 0.0001 )
165+ esm_lora_model_2 = copy .deepcopy (esm_lora_model )
166+ esm_optimizer = torch .optim .Adam (esm_lora_model_2 .parameters (), lr = 0.0001 )
167167 print ('\n TRAIN SIZE:' , train_size , '\n -------------------------------------------\n ' )
168168 get_vram ()
169169 try :
@@ -194,7 +194,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
194194 llm_dict_prosst = {
195195 'prosst' : {
196196 'llm_base_model' : prosst_base_model ,
197- 'llm_model' : prosst_lora_model ,
197+ 'llm_model' : prosst_lora_model_2 ,
198198 'llm_optimizer' : prosst_optimizer ,
199199 'llm_train_function' : prosst_train ,
200200 'llm_inference_function' : get_logits_from_full_seqs ,
@@ -208,7 +208,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
208208 llm_dict_esm = {
209209 'esm1v' : {
210210 'llm_base_model' : esm_base_model ,
211- 'llm_model' : esm_lora_model ,
211+ 'llm_model' : esm_lora_model_2 ,
212212 'llm_optimizer' : esm_optimizer ,
213213 'llm_train_function' : esm_train ,
214214 'llm_inference_function' : esm_infer ,
@@ -227,6 +227,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
227227 continue
228228 get_vram ()
229229 for i_m , method in enumerate ([None , llm_dict_esm , llm_dict_prosst ]):
230+ print ('~~~ ' + ['DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ][i_m ] + ' ~~~' )
230231 hm = DCALLMHybridModel (
231232 x_train_dca = np .array (x_dca_train ),
232233 y_train = y_train ,
@@ -238,7 +239,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
238239
239240 y_test_pred = hm .hybrid_prediction (
240241 x_dca = np .array (x_dca_test ),
241- x_llm = [None , np .array (x_llm_test_esm ), np .array (x_llm_test_prosst )][i ]
242+ x_llm = [None , np .asarray (x_llm_test_esm ), np .asarray (x_llm_test_prosst )][i_m ]
242243 )
243244
244245 print (f'Hybrid perf.: { spearmanr (y_test , y_test_pred )[0 ]} ' )
@@ -250,7 +251,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
250251 f'in N_Train = { train_size } and N_Test (N_Total - N_Train).' )
251252 hybrid_perfs .append (np .nan )
252253 ns_y_test .append (np .nan )
253- del prosst_lora_model
254+ del prosst_lora_model_2
255+ del esm_lora_model_2
254256 torch .cuda .empty_cache ()
255257 gc .collect ()
256258 dt = time .time () - start_time
0 commit comments