Skip to content

Commit c57879e

Browse files
committed
dev/fail: further test implementation of train_plm() (I/X)
1 parent fd3e25b commit c57879e

File tree

7 files changed

+217
-52
lines changed

7 files changed

+217
-52
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,18 @@ def __init__(
8484
self.llm_base_model = llm_model_input['esm1v']['llm_base_model']
8585
self.llm_model = llm_model_input['esm1v']['llm_model']
8686
self.llm_optimizer = llm_model_input['esm1v']['llm_optimizer']
87-
#self.llm_train_function = llm_model_input['esm1v']['llm_train_function']
87+
self.llm_train_function = llm_model_input['esm1v']['llm_train_function']
8888
self.llm_inference_function = llm_model_input['esm1v']['llm_inference_function']
8989
self.llm_loss_function = llm_model_input['esm1v']['llm_loss_function']
9090
self.x_train_llm = llm_model_input['esm1v']['x_llm']
91+
self.wt_input_ids = llm_model_input['esm1v']['wt_input_ids']
9192
self.llm_attention_mask = llm_model_input['esm1v']['llm_attention_mask']
9293
elif len(list(llm_model_input.keys())) == 1 and list(llm_model_input.keys())[0] == 'prosst':
9394
self.llm_key = 'prosst'
9495
self.llm_base_model = llm_model_input['prosst']['llm_base_model']
9596
self.llm_model = llm_model_input['prosst']['llm_model']
9697
self.llm_optimizer = llm_model_input['prosst']['llm_optimizer']
97-
#self.llm_train_function = llm_model_input['prosst']['llm_train_function']
98+
self.llm_train_function = llm_model_input['prosst']['llm_train_function']
9899
self.llm_inference_function = llm_model_input['prosst']['llm_inference_function']
99100
self.llm_loss_function = llm_model_input['prosst']['llm_loss_function']
100101
self.x_train_llm = llm_model_input['prosst']['x_llm']
@@ -432,16 +433,29 @@ def train_llm(self):
432433
)
433434
elif self.llm_key == 'esm1v':
434435
x_llm_ttest_b = torch.from_numpy(get_batches(self.x_llm_ttest, batch_size=1, dtype=int))
436+
#xs,
437+
#wt_input_ids,
438+
#attention_mask,
439+
#model,
440+
#mask_token_id = None,
441+
#inference_type='unmasked',
442+
#wt_structure_input_ids=None,
443+
#batch_size=5,
444+
#train=False,
445+
#device=None,
446+
#verbose=False,
435447
y_llm_ttest = self.llm_inference_function(
436-
xs=x_llm_ttest_b,
437-
model=self.llm_model,
448+
xs=self.x_llm_ttest,
449+
wt_input_ids=self.wt_input_ids,
438450
attention_mask=self.llm_attention_mask,
451+
model=self.llm_model,
439452
device=self.device
440453
)
441454
y_llm_ttrain = self.llm_inference_function(
442-
xs=x_llm_ttrain_b,
443-
model=self.llm_model,
455+
xs=self.x_llm_ttrain,
456+
wt_input_ids=self.wt_input_ids,
444457
attention_mask=self.llm_attention_mask,
458+
model=self.llm_model,
445459
device=self.device
446460
)
447461
logger.info(
@@ -493,7 +507,7 @@ def train_llm(self):
493507
)
494508
elif self.llm_key == 'esm1v':
495509
# xs, attns, scores, loss_fn, model, optimizer
496-
self.llm_train_function(
510+
self.llm_train_function(
497511
x_llm_ttrain_b,
498512
self.llm_attention_mask,
499513
scores_ttrain_b,

pypef/plm/esm_lora_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
4747

4848
def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
4949
tokenized_sequences = []
50-
for seq in tqdm(sequences, desc='Tokenizing sequences for ESM modeling', disable=not verbose):
50+
for seq in tqdm(sequences, desc='Tokenizing sequences', disable=not verbose):
5151
encoded_sequence, attention_mask = tokenizer(
5252
seq,
5353
padding='max_length',

pypef/plm/inference.py

Lines changed: 183 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,33 @@
44
# Some helper functions for inference of different models
55
# based on simple/wrapping functions
66

7+
import os
8+
import inspect
79
import numpy as np
10+
from scipy.stats import spearmanr
811
import torch
912
import torch.nn.functional as F
1013
from tqdm import tqdm
14+
from Bio import SeqIO
1115

1216
from pypef.utils.helpers import get_device
1317
from pypef.plm.utils import corr_loss, get_batches
1418
from pypef.plm.esm_lora_tune import get_esm_models, tokenize_sequences
1519

20+
1621
import logging
1722
logger = logging.getLogger('pypef.llm.inference')
1823

1924

25+
def checkpoint(model, filename):
    """Persist the model's current weights (state dict) to *filename*."""
    weights = model.state_dict()
    torch.save(weights, filename)
27+
28+
29+
def load_model(model, filename):
    """Restore model weights in place from a saved state-dict file."""
    logger.info(f'Loading best model: {os.path.abspath(filename)}...')
    state = torch.load(filename, weights_only=True)
    model.load_state_dict(state)
32+
33+
2034
def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
2135
tokenized_sequences = []
2236
for seq in tqdm(sequences, desc='Tokenizing sequences', disable=not verbose):
@@ -45,25 +59,34 @@ def unmasked_wt_score(
4559
device = get_device()
4660
if wt_input_ids.dim() == 1:
4761
wt_input_ids = wt_input_ids.unsqueeze(0)
62+
wt_input_ids = wt_input_ids.to(device)
4863
#structure_input_ids = model_kwargs.get("structure_input_ids", None)
4964

5065
attention_masks = torch.Tensor(np.full(
51-
shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64)
52-
if train:
53-
outputs = model(
54-
input_ids=wt_input_ids.to(device),
55-
attention_mask=attention_masks.to(device),
56-
**model_kwargs
57-
)
58-
59-
else:
60-
with torch.no_grad():
66+
shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64).to(device)
67+
try:
68+
if train:
6169
outputs = model(
62-
input_ids=wt_input_ids.to(device),
63-
attention_mask=attention_masks.to(device),
70+
input_ids=wt_input_ids,
71+
attention_mask=attention_masks,
6472
**model_kwargs
6573
)
6674

75+
else:
76+
with torch.no_grad():
77+
outputs = model(
78+
input_ids=wt_input_ids,
79+
attention_mask=attention_masks,
80+
**model_kwargs
81+
)
82+
except TypeError as e:
83+
print(f"Did not find model input keyword arguments (kwargs: "
84+
f"{model_kwargs.keys()}). Available kawrgs identified from "
85+
f"model.forward function inspect:\n"
86+
f"{inspect.signature(model.forward)}\nOriginal error:")
87+
raise e
88+
89+
6790
logits = outputs.logits
6891
logits = logits.squeeze(0) # remove batch dim
6992
# Better make sure that special tokens are always removed / masked
@@ -174,10 +197,8 @@ def mutation_only_mutation_masked_pll(
174197
output_hidden_states=False
175198
)
176199
logits = outputs.logits # (1, L, V)
177-
178200
log_probs = F.log_softmax(logits[0, pos], dim=-1)
179201
true_token = tokenized_seq[pos]
180-
181202
pll = pll + log_probs[true_token]
182203

183204
plls[i] = pll
@@ -186,8 +207,8 @@ def mutation_only_mutation_masked_pll(
186207

187208

188209
def mutation_all_pos_masked_pll(
189-
tokenized_sequences: torch.Tensor, # (L,)
190-
attention_mask: torch.Tensor, # (L,)
210+
tokenized_sequences: torch.Tensor, # (L,)
211+
attention_mask: torch.Tensor, # (L,)
191212
model,
192213
mask_token_id: int,
193214
train: bool = False,
@@ -275,7 +296,7 @@ def plm_inference(
275296
mask_token_id = None,
276297
inference_type='unmasked',
277298
wt_structure_input_ids=None,
278-
batch_size=5,
299+
batch_size: int | None = 5,
279300
train=False,
280301
device=None,
281302
verbose=False,
@@ -300,16 +321,18 @@ def plm_inference(
300321
raise SystemError("Choose between 'mutation-masking', 'unmasked', and 'full-masking'")
301322

302323
scores = []
303-
304-
xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
324+
if batch_size is None:
325+
xs_b = xs
326+
else:
327+
xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
305328
desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
306329

307330
kwargs = {}
308331
if mask_token_id is not None:
309332
kwargs["mask_token_id"] = mask_token_id
310333

311334
if wt_structure_input_ids is not None:
312-
kwargs["structure_input_ids"] = wt_structure_input_ids
335+
kwargs["ss_input_ids"] = wt_structure_input_ids.to(device)
313336

314337
pbar = tqdm(
315338
range(len(xs_b)),
@@ -332,6 +355,141 @@ def plm_inference(
332355
return torch.cat(scores)
333356

334357

358+
def plm_train(
    x_sequences,
    scores,
    loss_fn,
    model,
    optimizer,
    input_ids,
    attention_mask,
    batch_size: int = 5,
    n_epochs=50,
    device: str | None = None,
    seed: int | None = None,
    early_stop: int = 50,
    verbose: bool = True,
    wt_structure_input_ids=None,
    n_batch_grad_accumulations: int = 1,
    raise_error_on_train_fail: bool = True,
    progress_cb=None,
    abort_cb=None
):
    """Generic PLM fine-tuning wrapper around `plm_inference()`.

    Trains `model` on (x_sequences, scores) pairs in fixed-size batches
    with optional gradient accumulation. After every epoch the Spearman
    correlation between true and predicted scores is computed; the best
    checkpoint is kept on disk (under 'model_saves/') and reloaded into
    `model` at the end. Training stops early after `early_stop` epochs
    without improvement.

    Parameters
    ----------
    x_sequences : sequence of tokenized variant sequences.
    scores : sequence of target fitness scores (same length as x_sequences).
    loss_fn : callable(y_true, y_pred) -> differentiable scalar loss.
    model : the (LoRA-wrapped) PLM to fine-tune; modified in place.
    optimizer : torch optimizer bound to the model's trainable parameters.
    input_ids : wild-type input IDs forwarded to `plm_inference()`.
    attention_mask : attention mask forwarded to `plm_inference()`.
    batch_size : training batch size (remainder sequences are dropped,
        keep_remaining=False).
    n_epochs : maximum number of training epochs.
    device : compute device string; auto-detected if None.
    seed : optional torch manual seed for reproducibility.
    early_stop : epochs without Spearman improvement before stopping.
    verbose : show tqdm progress bars.
    wt_structure_input_ids : optional structure tokens (e.g. ProSST),
        forwarded to `plm_inference()`.
    n_batch_grad_accumulations : number of batches to accumulate gradients
        over before an optimizer step.
    raise_error_on_train_fail : raise RuntimeError if no model could be
        checkpointed (e.g. NaN correlation throughout); otherwise only warn.
    progress_cb : optional callable(epoch, batch, n_epochs, n_batches, loss)
        for external progress reporting.
    abort_cb : optional callable() -> bool; training returns early when it
        yields True.

    Returns
    -------
    None. The best checkpoint is loaded into `model` in place.

    Raises
    ------
    SystemError : if the epoch-wise Spearman correlation is NaN.
    RuntimeError : if no model improved and `raise_error_on_train_fail`.
    """
    if seed is not None:
        torch.manual_seed(seed)
    if device is None:
        device = get_device()
    x_sequences_batched = get_batches(
        x_sequences, dtype=int, batch_size=batch_size,
        keep_remaining=False, verbose=True
    ).to(device)
    score_batches = get_batches(
        scores, dtype=float, batch_size=batch_size,
        keep_remaining=False, verbose=True
    ).to(device)
    # Log AFTER batching: the original referenced score_batches before it
    # was assigned (NameError) and was hard-coded to "ProSST".
    logger.info(f"PLM training using {device.upper()} device "
                f"(N_Train={len(torch.flatten(score_batches))})...")
    n_batches = len(x_sequences_batched)
    pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
    best_spearman = -1.0
    did_not_improve_counter = 0
    best_model = None
    best_model_epoch = np.nan
    best_model_perf = np.nan
    loss = np.nan
    os.makedirs('model_saves', exist_ok=True)
    for epoch in pbar_epochs:
        # Original guarded this with `if epoch == 0`, which never fires
        # since the range starts at 1; show the plain description up front.
        pbar_epochs.set_description(f'Epoch {epoch}/{n_epochs}')
        model.train()
        y_preds_detached = []
        pbar_batches = tqdm(
            zip(x_sequences_batched, score_batches),
            total=n_batches, leave=False, disable=not verbose
        )
        for batch, (seqs_b, scores_b) in enumerate(pbar_batches):
            if abort_cb and abort_cb():
                return
            # Keyword arguments: plm_inference's order is
            # (xs, wt_input_ids, attention_mask, model, ...) -- the original
            # positional call passed `model` into the wt_input_ids slot.
            # batch_size=None because seqs_b already IS one batch.
            y_preds_b = plm_inference(
                xs=seqs_b,
                wt_input_ids=input_ids,
                attention_mask=attention_mask,
                model=model,
                wt_structure_input_ids=wt_structure_input_ids,
                batch_size=None,
                train=True,
                device=device,
                verbose=False
            )
            y_preds_detached.append(y_preds_b.detach().cpu().numpy().flatten())
            # Scale loss so accumulated gradients average over the
            # accumulation window.
            loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
            if progress_cb:
                progress_cb(epoch - 1, batch + 1, n_epochs, n_batches, loss)
            loss.backward()
            if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == n_batches:
                optimizer.step()
                optimizer.zero_grad()
            pbar_batches.set_description(
                f"Epoch: {epoch}. Loss: {loss.detach():>1f} "
                f"[batch: {batch + 1}/{n_batches} | "
                f"sequence: {(batch + 1) * len(seqs_b):>5d}/{n_batches * len(seqs_b)}] "
                f"({device.upper()})"
            )
        epoch_spearman = spearmanr(
            score_batches.cpu().numpy().flatten(),
            np.array(y_preds_detached).flatten()
        )[0]
        # np.isnan(): the original compared `== np.nan`, which is always
        # False, so the failure branch was unreachable.
        if np.isnan(epoch_spearman):
            raise SystemError(
                f"No correlation between Y_true and Y_pred could be computed...\n"
                f"Y_true: {score_batches.cpu().numpy().flatten()}, "
                f"Y_pred: {np.array(y_preds_detached)}"
            )
        # `or best_model is None` guarantees the first epoch is always
        # checkpointed (replaces the dead `or epoch == 0` guard).
        if epoch_spearman > best_spearman or best_model is None:
            if best_model is not None and os.path.isfile(best_model):
                os.remove(best_model)
            did_not_improve_counter = 0
            best_model_epoch = epoch
            best_model_perf = epoch_spearman
            best_model = (
                f"model_saves/Epoch{epoch}-Ntrain{len(score_batches.cpu().numpy().flatten())}"
                f"-SpearCorr{epoch_spearman:.3f}.pt"
            )
            checkpoint(model, best_model)
            best_spearman = epoch_spearman
        else:
            did_not_improve_counter += 1
            if did_not_improve_counter >= early_stop:
                logger.info(f'\nEarly stop at epoch {epoch}...')
                break
        loss_total = loss_fn(
            torch.flatten(score_batches).to('cpu'),
            torch.flatten(torch.Tensor(np.array(y_preds_detached).flatten()))
        )
        pbar_epochs.set_description(
            f'Epoch {epoch}/{n_epochs} [SpearCorr: {epoch_spearman:.3f}, Loss: {loss_total:.3f}] '
            f'(Best epoch: {best_model_epoch}: {best_model_perf:.3f})')
        if progress_cb:
            progress_cb(epoch, batch + 1, n_epochs, n_batches, loss)
    if best_model is None:
        msg = ("Failed to train a model (probably due to the input "
               "data characteristics and loss/correlation being NaN).")
        if raise_error_on_train_fail:
            raise RuntimeError(msg)
        else:
            logger.warning(f"{msg} Continuing nonetheless (using failed model "
                           f"and replacing NaN's with zeros)...")
    else:
        logger.info(f"Loading best model as {best_model}...")
        load_model(model, best_model)
    return
488+
489+
490+
491+
492+
335493
######################### Deprecated
336494

337495
def llm_tokenizer(llm_dict, seqs, verbose=True):
@@ -444,12 +602,12 @@ def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True
444602
'llm_base_model': esm_base_model,
445603
'llm_model': esm_lora_model,
446604
'llm_optimizer': esm_optimizer,
447-
#'llm_train_function': esm_train,
605+
'llm_train_function': plm_train,
448606
'llm_inference_function': plm_inference,
449607
'llm_loss_function': corr_loss,
450-
'x_llm' : x_esm,
451-
'input_ids': wt_tokens,
452-
'llm_attention_mask': esm_attention_mask,
608+
'x_llm' : torch.tensor(x_esm),
609+
'llm_attention_mask': torch.tensor(esm_attention_mask),
610+
'wt_input_ids': torch.tensor(wt_tokens),
453611
'llm_tokenizer': esm_tokenizer
454612
}
455613
}
@@ -491,7 +649,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
491649
'llm_base_model': prosst_base_model,
492650
'llm_model': prosst_lora_model,
493651
'llm_optimizer': prosst_optimizer,
494-
#'llm_train_function': prosst_train,
652+
'llm_train_function': plm_train,
495653
'llm_inference_function': plm_inference, # prosst_infer,
496654
'llm_loss_function': corr_loss,
497655
'x_llm' : x_llm_train_prosst,

pypef/plm/prosst_lora_tune.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,6 @@ def prosst_infer(
132132
)
133133

134134

135-
def checkpoint(model, filename):
136-
torch.save(model.state_dict(), filename)
137-
138-
139-
def load_model(model, filename):
140-
logger.info(f'Loading best model: {os.path.abspath(filename)}...')
141-
model.load_state_dict(torch.load(filename, weights_only=True))
142-
143-
144135
def prosst_train(
145136
x_sequence_batches, score_batches, loss_fn, model, optimizer,
146137
input_ids, attention_mask, structure_input_ids,

pypef/plm/prosst_structure/quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def process_subgraph(anchor_node):
486486
return anchor_node, subgraph
487487
for anchor_node in tqdm(
488488
anchor_nodes,
489-
desc=f'Getting ProSST structure embeddings ({device.upper()})',
489+
desc=f'Getting ProSST structure tokens ({device.upper()})',
490490
disable=not verbose
491491
):
492492
anchor, subgraph = process_subgraph(anchor_node)

0 commit comments

Comments
 (0)