Skip to content

Commit 1ceb458

Browse files
committed
Update hybrid model: works again (II)
1 parent 611e431 commit 1ceb458

File tree

6 files changed

+96
-172
lines changed

6 files changed

+96
-172
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,19 @@ def reduce_by_batch_modulo(a: np.ndarray, batch_size=5) -> np.ndarray:
6565
return a[:reduce]
6666

6767

68-
# TODO: Implementation of other regression techniques (CVRegression models)
68+
# TODO: Implementation of other regression techniques (CVRegression models) [Likely not worth]
6969
# TODO: Differential evolution of multiple Zero Shot predictors
70-
# (and supervised model predictions thereof) and y_true
70+
# (and supervised model predictions thereof) and y_true [DONE]
71+
# TODO: Add contrastive learning option (on PGym data)?
7172
class DCALLMHybridModel:
7273
def __init__(
7374
self,
74-
x_train_dca: np.ndarray, # DCA-encoded sequences
75-
y_train: np.ndarray, # true labels
75+
x_train_dca: np.ndarray,
76+
y_train: np.ndarray,
7677
llm_model_input: dict | None = None,
77-
x_wt: np.ndarray | None = None, # Wild type encoding
78-
alphas: np.ndarray | None = None, # Ridge regression grid for the parameter 'alpha'
79-
parameter_range: list[tuple] | None = None, # Parameter range of 'beta_1' and 'beta_2' with lower bound <= x <= upper bound,
78+
x_wt: np.ndarray | None = None,
79+
alphas: np.ndarray | None = None,
80+
parameter_range: list[tuple] | None = None,
8081
batch_size: int | None = None,
8182
device: str | None = None,
8283
seed: int | None = None
@@ -332,10 +333,10 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
332333
(train_size_fit * len(self.y_train)) -
333334
((train_size_fit * len(self.y_train)) % self.batch_size)
334335
)
335-
train_test_size = int(
336-
(len(self.y_train) - train_size_fit) -
337-
((len(self.y_train) - train_size_fit) % self.batch_size)
338-
)
336+
#train_test_size = int(
337+
# (len(self.y_train) - train_size_fit) -
338+
# ((len(self.y_train) - train_size_fit) % self.batch_size)
339+
#)
339340
(
340341
self.x_dca_ttrain, self.x_dca_ttest,
341342
self.x_llm_ttrain, self.x_llm_ttest,
@@ -347,15 +348,13 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
347348
train_size=train_size_fit,
348349
random_state=self.seed
349350
)
350-
# Reducing by batch size modulo for X, attention masks, and y
351+
# Reducing by batch size modulo for X and y
351352
self.x_dca_ttrain = self.x_dca_ttrain[:train_size_fit]
352353
self.x_llm_ttrain = self.x_llm_ttrain[:train_size_fit]
353-
#self.attn_llm_ttrain = self.attn_llm_ttrain[:train_size_fit]
354354
self.y_ttrain = self.y_ttrain[:train_size_fit]
355-
self.x_dca_ttest = self.x_dca_ttest[:train_test_size]
356-
self.x_llm_ttest = self.x_llm_ttest[:train_test_size]
357-
#self.attn_llm_ttest = self.attn_llm_ttest[:train_test_size]
358-
self.y_ttest = self.y_ttest[:train_test_size]
355+
#self.x_dca_ttest = self.x_dca_ttest[:train_test_size]
356+
#self.x_llm_ttest = self.x_llm_ttest[:train_test_size]
357+
#self.y_ttest = self.y_ttest[:train_test_size]
359358

360359
else:
361360
(
@@ -393,7 +392,7 @@ def train_llm(self):
393392
#get_batches(self.attn_llm_ttrain, batch_size=self.batch_size, dtype=int),
394393
get_batches(self.y_ttrain, batch_size=self.batch_size, dtype=float)
395394
)
396-
x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
395+
397396
#x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
398397
if self.llm_key == 'prosst':
399398
y_llm_ttest = self.llm_inference_function(
@@ -415,6 +414,7 @@ def train_llm(self):
415414
device=self.device
416415
)
417416
elif self.llm_key == 'esm1v':
417+
x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=1, dtype=int)
418418
y_llm_ttest = self.llm_inference_function(
419419
xs=x_llm_ttest_b,
420420
model=self.llm_model,
@@ -585,7 +585,6 @@ def hybrid_prediction(
585585
self.llm_attention_mask,
586586
self.structure_input_ids,
587587
train=False,
588-
#desc='Infering base model',
589588
device=self.device).detach().cpu().numpy()
590589
y_llm_lora = self.llm_inference_function(
591590
x_llm,
@@ -594,30 +593,20 @@ def hybrid_prediction(
594593
self.llm_attention_mask,
595594
self.structure_input_ids,
596595
train=False,
597-
#desc='Infering LoRA-tuned model',
598596
device=self.device).detach().cpu().numpy()
599597
elif self.llm_key == 'esm1v':
600-
x_llm_b = get_batches(x_llm, batch_size=self.batch_size, dtype=int)
598+
x_llm_b = get_batches(x_llm, batch_size=1, dtype=int)
601599
y_llm = self.llm_inference_function(
602600
x_llm_b,
603601
self.llm_attention_mask,
604602
self.llm_base_model,
605-
#desc='Infering base model',
606603
device=self.device).detach().cpu().numpy()
607604
y_llm_lora = self.llm_inference_function(
608605
x_llm_b,
609606
self.llm_attention_mask,
610607
self.llm_model,
611-
#desc='Infering LoRA-tuned model',
612608
device=self.device).detach().cpu().numpy()
613609

614-
615-
y_dca, y_ridge, y_llm, y_llm_lora = (
616-
reduce_by_batch_modulo(y_dca, batch_size=self.batch_size),
617-
reduce_by_batch_modulo(y_ridge, batch_size=self.batch_size),
618-
reduce_by_batch_modulo(y_llm, batch_size=self.batch_size),
619-
reduce_by_batch_modulo(y_llm_lora, batch_size=self.batch_size)
620-
)
621610
return self.beta1 * y_dca + self.beta2 * y_ridge + self.beta3 * y_llm + self.beta4 * y_llm_lora
622611

623612
def split_performance(

pypef/llm/esm_lora_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def esm_infer(xs, attention_mask, model, desc: None | str = None, device: str |
141141
device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
142142
attention_masks = torch.Tensor(np.full(shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
143143
print(f'Infering ESM model for predictions using {device.upper()} device...')
144-
for i , (xs_b, am_b) in enumerate(tqdm(zip(xs, attention_masks), total=len(xs), desc=desc)):
144+
for i , (xs_b, am_b) in enumerate(tqdm(zip(xs, attention_masks), total=len(xs), desc="Infering ESM model. Sequence")):
145145
xs_b = xs_b.to(torch.int64)
146146
with torch.no_grad():
147147
y_preds = get_y_pred_scores(xs_b, am_b, model, device)

pypef/llm/prosst_lora_tune.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def get_logits_from_full_seqs(
6969
attention_mask=attention_mask,
7070
ss_input_ids=structure_input_ids
7171
)
72-
7372
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze()
7473
for i_s, sequence in enumerate(tqdm(xs, disable=not verbose, desc='Getting ProSST sequence logits')):
7574
for i_aa, x_aa in enumerate(sequence):
@@ -84,9 +83,6 @@ def get_logits_from_full_seqs(
8483
return log_probs
8584

8685

87-
88-
89-
9086
def checkpoint(model, filename):
9187
torch.save(model.state_dict(), filename)
9288

@@ -107,7 +103,6 @@ def prosst_train(
107103
print(f'ProSST training using {device.upper()} device (N_Train={len(torch.flatten(score_batches))})...')
108104
x_sequence_batches = x_sequence_batches.to(device)
109105
score_batches = score_batches.to(device)
110-
111106
pbar_epochs = tqdm(range(1, n_epochs + 1))
112107
epoch_spearman_1 = 0.0
113108
did_not_improve_counter = 0
@@ -191,7 +186,6 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
191186
return input_ids, attention_mask, structure_input_ids
192187

193188

194-
195189
def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
196190
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
197191
prosst_vocab = prosst_tokenizer.get_vocab()
@@ -215,66 +209,3 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
215209
}
216210
}
217211
return llm_dict_prosst
218-
219-
220-
if __name__ == '__main__':
221-
import pandas as pd
222-
import copy
223-
from sklearn.model_selection import train_test_split
224-
import matplotlib.pyplot as plt
225-
# Test on dataset GRB2_HUMAN_Faure_2021: SignificanceResult(statistic=0.6997442598613315, pvalue=0.0)
226-
wt_seq = "MEAIAKYDFKATADDELSFKRGDILKVLNEECDQNWYKAELNGKDGFIPKNYIEMKPHPWFFGKIPRAKAEEMLSKQRHDGAFLIRESESAPGDFSLSVKFGNDVQHFKVLRDGAGKYFLWVVKFNSLNELVDYHRSTSVSRNQQIFLRDIEQVPQQPTYVQALFDFDPQEDGELGFRRGDFIHVMDNSDPNWWKGACHGQTGMFPRNYVTPVNRNV"
227-
grb2_folder = os.path.abspath(os.path.join(pypef_path, '..', 'datasets', 'GRB2'))
228-
pdb_file = os.path.join(grb2_folder, 'GRB2_HUMAN.pdb')
229-
csv_file = os.path.join(grb2_folder, 'GRB2_HUMAN_Faure_2021.csv')
230-
df = pd.read_csv(csv_file) #, nrows=120)
231-
print(df)
232-
prosst_base_model, prosst_lora_model, tokenizer, optimizer = get_prosst_models()
233-
vocab = tokenizer.get_vocab()
234-
structure_sequence = PdbQuantizer()(pdb_file=pdb_file)
235-
structure_sequence_offset = [i + 3 for i in structure_sequence]
236-
tokenized_res = tokenizer([wt_seq], return_tensors='pt')
237-
input_ids = tokenized_res['input_ids']
238-
attention_mask = tokenized_res['attention_mask']
239-
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2], dtype=torch.long).unsqueeze(0)
240-
#y_pred = get_logits_from_full_seqs(df['mutated_sequence'], prosst_model, input_ids, attention_mask, structure_input_ids, train=False)
241-
#print(spearmanr(df['DMS_score'], y_pred.detach().cpu().numpy())) # SignificanceResult(statistic=np.float64(0.7216670719282277), pvalue=np.float64(0.0))
242-
x_sequences = prosst_tokenize_sequences(df['mutated_sequence'], vocab=vocab)
243-
for batch_size in [5, 10, 25, 50, 100]:
244-
train_perfs_unsup, test_perfs_unsup = [], []
245-
train_perfs, test_perfs = [], []
246-
for train_size in [200, 1000, 10000]:
247-
prosst_model_copy = copy.deepcopy(prosst_base_model)
248-
x_train, x_test, scores_train, scores_test = train_test_split(
249-
x_sequences, df['DMS_score'].to_numpy().astype(float), train_size=train_size, random_state=42
250-
)
251-
print(f"\n=========================\nTRAIN SIZE: {train_size} TEST SIZE: {len(x_test)} -- BATCH SIZE: {batch_size}\n=========================")
252-
253-
y_pred = get_logits_from_full_seqs(
254-
x_test, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False)
255-
print(f'Train-->Test UNTRAINED Performance (N={len(y_pred.flatten())}):',spearmanr(scores_test, y_pred.detach().cpu().numpy()))
256-
test_perfs_unsup.append(spearmanr(scores_test, y_pred.detach().cpu().numpy()))
257-
258-
259-
y_preds_train_unsup = get_logits_from_full_seqs(
260-
x_train, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False, verbose=False)
261-
y_preds_train_unsup = y_preds_train_unsup.cpu().numpy()
262-
print(f'Train-->Train UNTRAINED Performance (N={len(y_preds_train_unsup)}):', spearmanr(scores_train, y_preds_train_unsup))
263-
train_perfs_unsup.append(spearmanr(scores_train, y_preds_train_unsup)[0])
264-
265-
# TRAINING
266-
x_train_b = get_batches(x_train, dtype=int, batch_size=batch_size, verbose=True)
267-
scores_train_b = get_batches(scores_train, dtype=float, batch_size=batch_size, verbose=True)
268-
y_preds_train = prosst_train(x_train_b, scores_train_b, corr_loss, prosst_model_copy, optimizer, pdb_file, n_epochs=500)
269-
print(f'Train-->Train Performance (N={len(y_preds_train)}):', spearmanr(scores_train, y_preds_train))
270-
train_perfs.append(spearmanr(scores_train, y_preds_train)[0])
271-
272-
y_pred = get_logits_from_full_seqs(
273-
x_test, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False)
274-
print(f'Train-->Test Performance (N={len(y_pred.flatten())}):', spearmanr(scores_test, y_pred.detach().cpu().numpy()))
275-
test_perfs.append(spearmanr(scores_test, y_pred.detach().cpu().numpy())[0])
276-
for k in [train_perfs_unsup, train_perfs, test_perfs_unsup, test_perfs]:
277-
plt.plot(range(len(k)), k, label=f'Batch size: {batch_size}')
278-
plt.xticks(range(len(k)), [100, 200, 1000, 10000])
279-
plt.legend()
280-
plt.savefig('1.png')

scripts/ProteinGym_runs/results/dca_esm_and_hybrid_opt_results.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ No.,Dataset,N_Variants,N_Max_Muts,Untrained_Performance_DCA,Untrained_Performanc
99
8,A4GRB6_PSEAI_Chen_2020,5004,1,0.6681056494435768,0.543247747835155,0.647351733217166,0.6105347017831526,0.6007582102931497,0.7414608515883568,0.7053638375627392,0.6894205591814954,0.7677593077079674,0.7245479703656822,0.8059450782290521,0.8107707930340956,5004,4904,4804,4004,1751
1010
9,AACC1_PSEAI_Dandage_2018,1801,1,0.3180712414525488,0.45793953382550573,0.36853292069676097,0.3174612456170627,0.45937287821213785,0.4310318712519802,0.3521756690003161,0.49206057023996924,0.4617532190070763,0.4488626249244256,0.5505813182399806,0.534569265039648,1801,1701,1601,801,869
1111
10,ACE2_HUMAN_Chan_2020,2223,1,0.24320754065919856,0.1855938942334426,0.2613054581997969,0.2985805551410494,0.24485718938989934,0.353286631689331,0.4023145866828806,0.3372015473315942,0.4700770240532049,0.5643356952550576,0.6012504479478733,0.610115781454495,2223,2123,2023,1223,3995
12+
11,ADRB2_HUMAN_Jones_2020,7800,1,0.5187856047925657,0.5310582600087359,0.5151672046363325,0.5183272590334468,0.530515316113512,0.5380119736507377,0.5174727374036995,0.5384330679901763,0.5197786372610751,0.5147111407689821,0.5570931245349682,0.553912098700521,7800,7700,7600,6800,3287
13+
12,AICDA_HUMAN_Gajula_2014_3cycles,209,1,0.41950521618921593,0.4075489898558927,0.274172920796004,0.4423419549584822,0.5069492105992409,0.4054703600769246,nan,nan,nan,nan,nan,nan,209,109,nan,nan,96

scripts/ProteinGym_runs/results/dca_esm_and_hybrid_opt_results_clean.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ No.,Dataset,N_Variants,N_Max_Muts,Untrained_Performance_DCA,Untrained_Performanc
99
8,A4GRB6_PSEAI_Chen_2020,5004,1,0.6681056494435768,0.543247747835155,0.647351733217166,0.6105347017831526,0.6007582102931497,0.7414608515883568,0.7053638375627392,0.6894205591814954,0.7677593077079674,0.7245479703656822,0.8059450782290521,0.8107707930340956,5004,4904,4804,4004,1751
1010
9,AACC1_PSEAI_Dandage_2018,1801,1,0.3180712414525488,0.45793953382550573,0.36853292069676097,0.3174612456170627,0.45937287821213785,0.4310318712519802,0.3521756690003161,0.49206057023996924,0.4617532190070763,0.4488626249244256,0.5505813182399806,0.534569265039648,1801,1701,1601,801,869
1111
10,ACE2_HUMAN_Chan_2020,2223,1,0.24320754065919856,0.1855938942334426,0.2613054581997969,0.2985805551410494,0.24485718938989934,0.353286631689331,0.4023145866828806,0.3372015473315942,0.4700770240532049,0.5643356952550576,0.6012504479478733,0.610115781454495,2223,2123,2023,1223,3995
12+
11,ADRB2_HUMAN_Jones_2020,7800,1,0.5187856047925657,0.5310582600087359,0.5151672046363325,0.5183272590334468,0.530515316113512,0.5380119736507377,0.5174727374036995,0.5384330679901763,0.5197786372610751,0.5147111407689821,0.5570931245349682,0.553912098700521,7800,7700,7600,6800,3287
13+
12,AICDA_HUMAN_Gajula_2014_3cycles,209,1,0.41950521618921593,0.4075489898558927,0.274172920796004,0.4423419549584822,0.5069492105992409,0.4054703600769246,nan,nan,nan,nan,nan,nan,209,109,nan,nan,96

0 commit comments

Comments
 (0)