Skip to content

Commit 31c15e2

Browse files
committed
dca_llm_hybrid PGym for all three hybrid options
1 parent 6ec1967 commit 31c15e2

File tree

3 files changed

+40
-29
lines changed

3 files changed

+40
-29
lines changed

pypef/dca/gremlin_inference.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@
5959
import pandas as pd
6060
from tqdm import tqdm
6161
import torch
62-
# Uncomment to hide GPU devices
63-
#environ['CUDA_VISIBLE_DEVICES'] = '-1'
6462

6563

6664
class GREMLIN:

pypef/hybrid/hybrid_model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
else:
116116
print("No LLM inputs were defined for hybrid modelling. "
117117
"Using only DCA for hybrid modeling...")
118+
self.llm_attention_mask = None
118119
if parameter_range is None:
119120
parameter_range = [(0, 1), (0, 1)]
120121
if alphas is None:
@@ -337,12 +338,10 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
337338
(
338339
self.x_dca_ttrain, self.x_dca_ttest,
339340
self.x_llm_ttrain, self.x_llm_ttest,
340-
#self.attn_llm_ttrain, self.attn_llm_ttest,
341341
self.y_ttrain, self.y_ttest
342342
) = train_test_split(
343343
self.x_train_dca,
344344
self.x_train_llm,
345-
#self.llm_attention_mask,
346345
self.y_train,
347346
train_size=train_size_fit,
348347
random_state=self.seed
@@ -525,8 +524,7 @@ def train_and_optimize(self) -> tuple:
525524
def hybrid_prediction(
526525
self,
527526
x_dca: np.ndarray,
528-
x_llm: None | np.ndarray,
529-
attns_llm: None | np.ndarray
527+
x_llm: None | np.ndarray
530528
) -> np.ndarray:
531529
"""
532530
Use the regressor 'reg' and the parameters 'beta_1'
@@ -555,7 +553,13 @@ def hybrid_prediction(
555553
else:
556554
y_ridge = self.ridge_opt.predict(x_dca)
557555

558-
if x_llm is None or attns_llm is None:
556+
if x_llm is None:
557+
if self.llm_attention_mask is not None:
558+
print('No LLM input for hybrid prediction but the model '
559+
'has been trained using an LLM model input.. '
560+
'Using only DCA for hybridprediction.. This can lead '
561+
'to unwanted prediction behavior if the hybrid model '
562+
'is trained including an LLM...')
559563
return self.beta1 * y_dca + self.beta2
560564

561565
else:

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,24 @@
1616
import sys # Use local directory PyPEF files
1717
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
1818
from pypef.dca.gremlin_inference import GREMLIN
19-
from pypef.llm.esm_lora_tune import get_esm_models, esm_tokenize_sequences, get_batches, esm_train, esm_test, esm_infer, corr_loss
19+
from pypef.llm.esm_lora_tune import (
20+
get_esm_models, esm_tokenize_sequences,
21+
get_batches, esm_train, esm_infer, corr_loss
22+
)
2023
from pypef.llm.prosst_lora_tune import (
21-
get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied, prosst_tokenize_sequences, prosst_train)
24+
get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
25+
prosst_tokenize_sequences, prosst_train
26+
)
2227
from pypef.utils.variant_data import get_seqs_from_var_name
23-
from pypef.hybrid.hybrid_model import DCALLMHybridModel, reduce_by_batch_modulo, get_delta_e_statistical_model
28+
from pypef.hybrid.hybrid_model import (
29+
DCALLMHybridModel, reduce_by_batch_modulo, get_delta_e_statistical_model
30+
)
2431

2532

2633
def get_vram(verbose: bool = True):
34+
if not torch.cuda.is_available():
35+
print("No CUDA/GPU device available for VRAM checking...")
36+
return
2737
free = torch.cuda.mem_get_info()[0] / 1024 ** 3
2838
total = torch.cuda.mem_get_info()[1] / 1024 ** 3
2939
total_cubes = 24
@@ -105,8 +115,9 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
105115
# Only model sequences with length of max. 800 amino acids to avoid out of memory errors
106116
print('Sequence length:', len(wt_seq))
107117
if len(wt_seq) > MAX_WT_SEQUENCE_LENGTH:
108-
print(f'Sequence length over {MAX_WT_SEQUENCE_LENGTH}, which represents a potential out-of-memory risk '
109-
f'(when running on GPU, set threshold to length ~400 dependent on available VRAM), '
118+
print(f'Sequence length over {MAX_WT_SEQUENCE_LENGTH}, which represents '
119+
f'a potential out-of-memory risk (when running on GPU, set '
120+
f'threshold to length ~400 dependent on available VRAM); '
110121
f'skipping dataset...')
111122
continue
112123
variants, variants_split, sequences, fitnesses = (
@@ -160,13 +171,11 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
160171
x_dca_train, x_dca_test,
161172
x_llm_train_prosst, x_llm_test_prosst,
162173
x_llm_train_esm, x_llm_test_esm,
163-
#attns_train, attns_test,
164174
y_train, y_test
165175
) = train_test_split(
166176
x_dca,
167177
x_prosst,
168178
x_esm,
169-
#attention_mask,
170179
fitnesses,
171180
train_size=train_size,
172181
random_state=42
@@ -217,24 +226,24 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
217226
ns_y_test.append(np.nan)
218227
continue
219228
get_vram()
220-
hm = DCALLMHybridModel(
221-
x_train_dca=np.array(x_dca_train),
222-
y_train=y_train,
223-
llm_model_input=llm_dict_esm,#prosst,
224-
x_wt=x_wt
225-
)
229+
for i_m, method in enumerate([None, llm_dict_esm, llm_dict_prosst]):
230+
hm = DCALLMHybridModel(
231+
x_train_dca=np.array(x_dca_train),
232+
y_train=y_train,
233+
llm_model_input=method,
234+
x_wt=x_wt
235+
)
226236

227-
y_test = reduce_by_batch_modulo(y_test)
237+
y_test = reduce_by_batch_modulo(y_test)
228238

229-
y_test_pred = hm.hybrid_prediction(
230-
x_dca=np.array(x_dca_test),
231-
x_llm=np.array(x_llm_test_esm),
232-
attns_llm=prosst_attention_mask # np.array(attns_test)
233-
)
239+
y_test_pred = hm.hybrid_prediction(
240+
x_dca=np.array(x_dca_test),
241+
x_llm=[None, np.array(x_llm_test_esm), np.array(x_llm_test_prosst)][i]
242+
)
234243

235-
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
236-
hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0])
237-
ns_y_test.append(len(y_test_pred))
244+
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
245+
hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0])
246+
ns_y_test.append(len(y_test_pred))
238247
except ValueError as e:
239248
#raise e
240249
print(f'Only {len(fitnesses)} variant-fitness pairs in total, cannot split the data '

0 commit comments

Comments
 (0)