|
16 | 16 | import sys # Use local directory PyPEF files |
17 | 17 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) |
18 | 18 | from pypef.dca.gremlin_inference import GREMLIN |
19 | | -from pypef.llm.esm_lora_tune import get_esm_models, esm_tokenize_sequences, get_batches, esm_train, esm_test, esm_infer, corr_loss |
| 19 | +from pypef.llm.esm_lora_tune import ( |
| 20 | + get_esm_models, esm_tokenize_sequences, |
| 21 | + get_batches, esm_train, esm_infer, corr_loss |
| 22 | +) |
20 | 23 | from pypef.llm.prosst_lora_tune import ( |
21 | | - get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied, prosst_tokenize_sequences, prosst_train) |
| 24 | + get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied, |
| 25 | + prosst_tokenize_sequences, prosst_train |
| 26 | +) |
22 | 27 | from pypef.utils.variant_data import get_seqs_from_var_name |
23 | | -from pypef.hybrid.hybrid_model import DCALLMHybridModel, reduce_by_batch_modulo, get_delta_e_statistical_model |
| 28 | +from pypef.hybrid.hybrid_model import ( |
| 29 | + DCALLMHybridModel, reduce_by_batch_modulo, get_delta_e_statistical_model |
| 30 | +) |
24 | 31 |
|
25 | 32 |
|
26 | 33 | def get_vram(verbose: bool = True): |
| 34 | + if not torch.cuda.is_available(): |
| 35 | + print("No CUDA/GPU device available for VRAM checking...") |
| 36 | + return |
27 | 37 | free = torch.cuda.mem_get_info()[0] / 1024 ** 3 |
28 | 38 | total = torch.cuda.mem_get_info()[1] / 1024 ** 3 |
29 | 39 | total_cubes = 24 |
@@ -105,8 +115,9 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested |
105 | 115 | # Only model sequences with length of max. 800 amino acids to avoid out of memory errors |
106 | 116 | print('Sequence length:', len(wt_seq)) |
107 | 117 | if len(wt_seq) > MAX_WT_SEQUENCE_LENGTH: |
108 | | - print(f'Sequence length over {MAX_WT_SEQUENCE_LENGTH}, which represents a potential out-of-memory risk ' |
109 | | - f'(when running on GPU, set threshold to length ~400 dependent on available VRAM), ' |
| 118 | + print(f'Sequence length over {MAX_WT_SEQUENCE_LENGTH}, which represents ' |
| 119 | + f'a potential out-of-memory risk (when running on GPU, set ' |
| 120 | + f'threshold to length ~400 dependent on available VRAM); ' |
110 | 121 | f'skipping dataset...') |
111 | 122 | continue |
112 | 123 | variants, variants_split, sequences, fitnesses = ( |
@@ -160,13 +171,11 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested |
160 | 171 | x_dca_train, x_dca_test, |
161 | 172 | x_llm_train_prosst, x_llm_test_prosst, |
162 | 173 | x_llm_train_esm, x_llm_test_esm, |
163 | | - #attns_train, attns_test, |
164 | 174 | y_train, y_test |
165 | 175 | ) = train_test_split( |
166 | 176 | x_dca, |
167 | 177 | x_prosst, |
168 | 178 | x_esm, |
169 | | - #attention_mask, |
170 | 179 | fitnesses, |
171 | 180 | train_size=train_size, |
172 | 181 | random_state=42 |
@@ -217,24 +226,24 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested |
217 | 226 | ns_y_test.append(np.nan) |
218 | 227 | continue |
219 | 228 | get_vram() |
220 | | - hm = DCALLMHybridModel( |
221 | | - x_train_dca=np.array(x_dca_train), |
222 | | - y_train=y_train, |
223 | | - llm_model_input=llm_dict_esm,#prosst, |
224 | | - x_wt=x_wt |
225 | | - ) |
| 229 | + for i_m, method in enumerate([None, llm_dict_esm, llm_dict_prosst]): |
| 230 | + hm = DCALLMHybridModel( |
| 231 | + x_train_dca=np.array(x_dca_train), |
| 232 | + y_train=y_train, |
| 233 | + llm_model_input=method, |
| 234 | + x_wt=x_wt |
| 235 | + ) |
226 | 236 |
|
227 | | - y_test = reduce_by_batch_modulo(y_test) |
| 237 | + y_test = reduce_by_batch_modulo(y_test) |
228 | 238 |
|
229 | | - y_test_pred = hm.hybrid_prediction( |
230 | | - x_dca=np.array(x_dca_test), |
231 | | - x_llm=np.array(x_llm_test_esm), |
232 | | - attns_llm=prosst_attention_mask # np.array(attns_test) |
233 | | - ) |
| 239 | + y_test_pred = hm.hybrid_prediction( |
| 240 | + x_dca=np.array(x_dca_test), |
| 241 | +            x_llm=[None, np.array(x_llm_test_esm), np.array(x_llm_test_prosst)][i_m]
| 242 | + ) |
234 | 243 |
|
235 | | - print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}') |
236 | | - hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0]) |
237 | | - ns_y_test.append(len(y_test_pred)) |
| 244 | + print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}') |
| 245 | + hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0]) |
| 246 | + ns_y_test.append(len(y_test_pred)) |
238 | 247 | except ValueError as e: |
239 | 248 | #raise e |
240 | 249 | print(f'Only {len(fitnesses)} variant-fitness pairs in total, cannot split the data ' |
|
0 commit comments