Skip to content

Commit c5bc869

Browse files
committed
dev/fail: add full-sequence-log-likelihood plm inference func
1 parent c8d4b28 commit c5bc869

File tree

7 files changed

+323
-370
lines changed

7 files changed

+323
-370
lines changed

pypef/gaussian_process/composite.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
from sklearn.model_selection import train_test_split
33
import torch
44

5-
from gpytorch.kernels import ScaleKernel
65
import gpytorch
6+
from gpytorch.kernels import ScaleKernel
77
import pandas as pd
88
from tqdm import tqdm
99

10-
from gp_esm2_test import extract_esm_embeddings
11-
from gp_pmpnn_test import HellingerRBFKernel, get_probs_from_mutations
12-
from gp_prosst_test import (extract_prosst_embeddings, get_prosst_models,
10+
from pypef.gaussian_process.gp_esm2_test import extract_esm_embeddings
11+
from pypef.gaussian_process.gp_pmpnn_test import HellingerRBFKernel, get_probs_from_mutations
12+
from pypef.gaussian_process.gp_prosst_test import (extract_prosst_embeddings, get_prosst_models,
1313
get_structure_quantizied, read_fasta_biopython)
14-
from metrics import spearman_soft, spearman_corr_differentiable, spearmanr2
14+
from pypef.gaussian_process.metrics import spearman_soft, spearman_corr_differentiable, spearmanr2
1515

1616
class CombinedKernel(gpytorch.kernels.Kernel):
1717
"""

pypef/hybrid/hybrid_model.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,15 @@ def train_llm(self):
420420
model=self.llm_base_model,
421421
wt_input_ids=self.wt_input_ids,
422422
attention_mask=self.llm_attention_mask,
423-
structure_input_ids=self.structure_input_ids,
423+
wt_structure_input_ids=self.structure_input_ids,
424424
device=self.device
425425
)
426426
y_llm_ttrain = self.llm_inference_function(
427427
xs=self.x_llm_ttrain,
428428
model=self.llm_base_model,
429429
wt_input_ids=self.wt_input_ids,
430430
attention_mask=self.llm_attention_mask,
431-
structure_input_ids=self.structure_input_ids,
431+
wt_structure_input_ids=self.structure_input_ids,
432432
device=self.device
433433
)
434434
elif self.llm_key == 'esm1v':
@@ -472,14 +472,14 @@ def train_llm(self):
472472
# void function, training model in place
473473
if self.llm_key == 'prosst':
474474
self.llm_train_function(
475-
self.x_llm_ttrain,
476-
self.y_ttrain,
477-
self.llm_loss_function,
478-
self.llm_model,
479-
self.llm_optimizer,
480-
self.wt_input_ids,
481-
self.llm_attention_mask,
482-
self.structure_input_ids,
475+
x_sequences=self.x_llm_ttrain,
476+
scores=self.y_ttrain,
477+
loss_fn=self.llm_loss_function,
478+
model=self.llm_model,
479+
optimizer=self.llm_optimizer,
480+
wt_input_ids=self.wt_input_ids,
481+
attention_mask=self.llm_attention_mask,
482+
wt_structure_input_ids=self.structure_input_ids,
483483
n_epochs=50,
484484
device=self.device,
485485
verbose=self.verbose,
@@ -490,18 +490,18 @@ def train_llm(self):
490490
y_llm_lora_ttrain = self.llm_inference_function(
491491
xs=self.x_llm_ttrain,
492492
model=self.llm_model,
493-
input_ids=self.wt_input_ids,
493+
wt_input_ids=self.wt_input_ids,
494494
attention_mask=self.llm_attention_mask,
495-
structure_input_ids=self.structure_input_ids,
495+
wt_structure_input_ids=self.structure_input_ids,
496496
device=self.device,
497497
verbose=self.verbose
498498
)
499499
y_llm_lora_ttest = self.llm_inference_function(
500500
xs=self.x_llm_ttest,
501501
model=self.llm_model,
502-
input_ids=self.wt_input_ids,
502+
wt_input_ids=self.wt_input_ids,
503503
attention_mask=self.llm_attention_mask,
504-
structure_input_ids=self.structure_input_ids,
504+
wt_structure_input_ids=self.structure_input_ids,
505505
device=self.device,
506506
verbose=self.verbose
507507
)

pypef/plm/esm_lora_tune.py

Lines changed: 1 addition & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,12 @@
2121
logger = logging.getLogger('pypef.llm.esm_lora_tune')
2222

2323
import torch
24-
import numpy as np
25-
from scipy.stats import spearmanr
26-
from tqdm import tqdm
2724

2825
from peft import LoraConfig, get_peft_model
2926
from transformers import logging as hf_logging
3027
hf_logging.set_verbosity_error()
3128

32-
from pypef.utils.helpers import get_device
33-
from pypef.plm.utils import corr_loss, get_batches, load_model_and_tokenizer
29+
from pypef.plm.utils import load_model_and_tokenizer
3430

3531

3632
def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
@@ -43,140 +39,3 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
4339
lora_model = get_peft_model(base_model, peft_config)
4440
optimizer = torch.optim.Adam(lora_model.parameters(), lr=0.01)
4541
return base_model, lora_model, tokenizer, optimizer
46-
47-
48-
49-
def get_y_pred_scores(encoded_sequences, attention_masks,
50-
model, device: str | None = None):
51-
if device is None:
52-
device = get_device()
53-
model = model.to(device)
54-
out = model(encoded_sequences.to(device), attention_masks.to(device),
55-
output_hidden_states=True)
56-
logits = out.logits
57-
token_probs = torch.log_softmax(logits, dim=-1)
58-
for i_s, sequence in enumerate(encoded_sequences):
59-
for i_aa, aa in enumerate(sequence):
60-
# alternative: use Tensor.index_select() function
61-
if i_aa == 0:
62-
seq_log_probs = token_probs[i_s, i_aa, aa].reshape(1)
63-
else:
64-
seq_log_probs = torch.cat(
65-
(seq_log_probs, token_probs[i_s, i_aa, aa].reshape(1)), 0)
66-
if i_s == 0:
67-
log_probs = torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
68-
else:
69-
log_probs = torch.cat(
70-
(log_probs, torch.sum(torch.Tensor(seq_log_probs)).reshape(1)), 0)
71-
return log_probs
72-
73-
74-
def esm_test(xs, attention_mask, scores, loss_fn, model,
75-
device: str | None = None, verbose: bool = True):
76-
if device is None:
77-
device = get_device()
78-
attention_masks = torch.Tensor(np.full(
79-
shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
80-
logger.info(f'Infering ESM model for testing using {device.upper()} device...')
81-
model = model.to(device)
82-
xs, attention_masks, scores = (
83-
torch.Tensor(xs).to(device), attention_masks.to(device),
84-
torch.Tensor(scores).to(torch.float).to(device)
85-
)
86-
pbar_epochs = tqdm(zip(xs, attention_masks, scores), total=len(xs), disable=not verbose)
87-
for i ,(xs_b, attns_b, scores_b) in enumerate(pbar_epochs):
88-
xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
89-
with torch.no_grad():
90-
y_preds = get_y_pred_scores(xs_b, attns_b, model, device)
91-
if i == 0:
92-
y_preds_total = y_preds
93-
scores_total = scores_b
94-
else:
95-
y_preds_total = torch.cat((y_preds_total, y_preds))
96-
scores_total = torch.cat((scores_total, scores_b))
97-
batch_loss = loss_fn(scores_b, y_preds)
98-
total_loss = loss_fn(torch.flatten(scores_total), torch.flatten(y_preds_total))
99-
batch_scorr = spearmanr(scores_b.cpu(), y_preds.cpu())[0]
100-
total_scorr = spearmanr(scores_total.cpu(), y_preds_total.cpu())[0]
101-
pbar_epochs.set_description(
102-
f"Testing: Batch {i + 1}/{len(xs)} | Batch loss: {batch_loss:.4f} (SpearCorr: "
103-
f"{batch_scorr:.4f})| Total loss: {total_loss:.4f} (SpearCorr: {total_scorr:.4f})")
104-
logger.info(f"Test performance: Loss: {total_loss:.4f}, SpearCorr: {total_scorr:.4f} "
105-
f"({device.upper()})")
106-
return torch.flatten(scores).detach().cpu(), torch.flatten(y_preds_total).detach().cpu()
107-
108-
109-
def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=False):
110-
if device is None:
111-
device = get_device()
112-
attention_masks = torch.Tensor(np.full(
113-
shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
114-
if verbose:
115-
logger.info(f'Infering ESM model for predictions using {device.upper()} device...')
116-
for i , (xs_b, am_b) in enumerate(tqdm(
117-
zip(xs, attention_masks), total=len(xs),
118-
desc=f"ESM inference - processing sequences ({device.upper()})",
119-
disable=not verbose
120-
)):
121-
xs_b = xs_b.to(torch.int64)
122-
with torch.no_grad():
123-
y_preds = get_y_pred_scores(xs_b, am_b, model, device)
124-
if i == 0:
125-
y_preds_total = y_preds
126-
else:
127-
y_preds_total = torch.cat((y_preds_total, y_preds))
128-
return torch.flatten(y_preds_total)
129-
130-
131-
def esm_train(
132-
xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
133-
device: str | None = None, seed: int | None = None,
134-
n_batch_grad_accumulations: int = 1, verbose: bool = True,
135-
progress_cb=None, abort_cb=None
136-
):
137-
if seed is not None:
138-
torch.manual_seed(seed)
139-
if device is None:
140-
device = get_device()
141-
print(f'Training ESM model using {device.upper()} device '
142-
f'(N_Train={len(torch.flatten(scores))})...')
143-
model = model.to(device)
144-
attention_masks = torch.Tensor(np.full(
145-
shape=np.shape(xs), fill_value=attention_mask)).to(torch.int64)
146-
xs, attention_masks, scores = xs.to(device), attention_masks.to(device), scores.to(device)
147-
pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
148-
loss = np.nan
149-
for epoch in pbar_epochs:
150-
try:
151-
pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}. Loss: {loss.detach():>1f}')
152-
except AttributeError:
153-
pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}')
154-
model.train()
155-
pbar_batches = tqdm(
156-
zip(xs, attention_masks, scores),
157-
total=len(xs), leave=False, disable=not verbose
158-
)
159-
for batch, (xs_b, attns_b, scores_b) in enumerate(pbar_batches):
160-
if abort_cb and abort_cb():
161-
return
162-
xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
163-
y_preds_b = get_y_pred_scores(xs_b, attns_b, model, device=device)
164-
loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
165-
if progress_cb:
166-
progress_cb(epoch - 1, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
167-
loss.backward()
168-
if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == len(pbar_batches):
169-
optimizer.step()
170-
optimizer.zero_grad()
171-
pbar_batches.set_description(
172-
f"Epoch: {epoch}. Loss: {loss.detach():>1f} "
173-
f"[batch: {batch+1}/{len(xs)} | sequence: "
174-
f"{(batch + 1) * len(xs_b):>5d}/{len(xs) * len(xs_b)}] ({device.upper()})"
175-
)
176-
if progress_cb:
177-
progress_cb(epoch, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
178-
y_preds_b = y_preds_b.detach()
179-
model.train(False)
180-
181-
182-

0 commit comments

Comments
 (0)