Skip to content

Commit c8d4b28

Browse files
committed
dev/fail: train_plm() seems to work for ESM (II)
1 parent c57879e commit c8d4b28

File tree

5 files changed

+107
-79
lines changed

5 files changed

+107
-79
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def __init__(
100100
self.llm_loss_function = llm_model_input['prosst']['llm_loss_function']
101101
self.x_train_llm = llm_model_input['prosst']['x_llm']
102102
self.llm_attention_mask = llm_model_input['prosst']['llm_attention_mask']
103-
self.input_ids = llm_model_input['prosst']['input_ids']
103+
self.wt_input_ids = llm_model_input['prosst']['wt_input_ids']
104104
self.structure_input_ids = llm_model_input['prosst']['structure_input_ids']
105105
else:
106106
raise RuntimeError("LLM input model dictionary not supported. Currently supported "
@@ -418,15 +418,15 @@ def train_llm(self):
418418
y_llm_ttest = self.llm_inference_function(
419419
xs=self.x_llm_ttest,
420420
model=self.llm_base_model,
421-
input_ids=self.input_ids,
421+
wt_input_ids=self.wt_input_ids,
422422
attention_mask=self.llm_attention_mask,
423423
structure_input_ids=self.structure_input_ids,
424424
device=self.device
425425
)
426426
y_llm_ttrain = self.llm_inference_function(
427427
xs=self.x_llm_ttrain,
428428
model=self.llm_base_model,
429-
input_ids=self.input_ids,
429+
wt_input_ids=self.wt_input_ids,
430430
attention_mask=self.llm_attention_mask,
431431
structure_input_ids=self.structure_input_ids,
432432
device=self.device
@@ -472,12 +472,12 @@ def train_llm(self):
472472
# void function, training model in place
473473
if self.llm_key == 'prosst':
474474
self.llm_train_function(
475-
x_llm_ttrain_b,
476-
scores_ttrain_b,
475+
self.x_llm_ttrain,
476+
self.y_ttrain,
477477
self.llm_loss_function,
478478
self.llm_model,
479479
self.llm_optimizer,
480-
self.input_ids,
480+
self.wt_input_ids,
481481
self.llm_attention_mask,
482482
self.structure_input_ids,
483483
n_epochs=50,
@@ -490,7 +490,7 @@ def train_llm(self):
490490
y_llm_lora_ttrain = self.llm_inference_function(
491491
xs=self.x_llm_ttrain,
492492
model=self.llm_model,
493-
input_ids=self.input_ids,
493+
input_ids=self.wt_input_ids,
494494
attention_mask=self.llm_attention_mask,
495495
structure_input_ids=self.structure_input_ids,
496496
device=self.device,
@@ -499,37 +499,47 @@ def train_llm(self):
499499
y_llm_lora_ttest = self.llm_inference_function(
500500
xs=self.x_llm_ttest,
501501
model=self.llm_model,
502-
input_ids=self.input_ids,
502+
input_ids=self.wt_input_ids,
503503
attention_mask=self.llm_attention_mask,
504504
structure_input_ids=self.structure_input_ids,
505505
device=self.device,
506506
verbose=self.verbose
507507
)
508508
elif self.llm_key == 'esm1v':
509509
# xs, attns, scores, loss_fn, model, optimizer
510+
# x_sequences,
511+
# scores,
512+
# loss_fn,
513+
# model,
514+
# optimizer,
515+
# input_ids,
516+
# attention_mask,
510517
self.llm_train_function(
511-
x_llm_ttrain_b,
512-
self.llm_attention_mask,
513-
scores_ttrain_b,
514-
self.llm_loss_function,
515-
self.llm_model,
516-
self.llm_optimizer,
517-
n_epochs=5,
518+
x_sequences=self.x_llm_ttrain,
519+
scores=self.y_ttrain,
520+
loss_fn=self.llm_loss_function,
521+
model=self.llm_model,
522+
optimizer=self.llm_optimizer,
523+
wt_input_ids=self.wt_input_ids,
524+
attention_mask=self.llm_attention_mask,
525+
n_epochs=50,
518526
device=self.device,
519527
verbose=self.verbose,
520528
progress_cb=self.progress_cb,
521529
abort_cb=self.abort_cb
522530
)
523531
y_llm_lora_ttrain = self.llm_inference_function(
524-
xs=x_llm_ttrain_b,
532+
xs=self.x_llm_ttrain,
525533
model=self.llm_model,
526534
attention_mask=self.llm_attention_mask,
535+
wt_input_ids=self.wt_input_ids,
527536
device=self.device,
528537
verbose=self.verbose
529538
)
530539
y_llm_lora_ttest = self.llm_inference_function(
531-
xs=x_llm_ttest_b,
540+
xs=self.x_llm_ttest,
532541
model=self.llm_model,
542+
wt_input_ids=self.wt_input_ids,
533543
attention_mask=self.llm_attention_mask,
534544
device=self.device,
535545
verbose=self.verbose
@@ -630,36 +640,44 @@ def hybrid_prediction(
630640

631641
else:
632642
if self.llm_key == 'prosst':
643+
# xs,
644+
#wt_input_ids,
645+
#attention_mask,
646+
#model,
633647
y_llm = self.llm_inference_function(
634-
x_llm,
635-
self.llm_base_model,
636-
self.input_ids,
637-
self.llm_attention_mask,
638-
self.structure_input_ids,
648+
xs=x_llm,
649+
wt_input_ids=self.wt_input_ids,
650+
attention_mask=self.llm_attention_mask,
651+
model=self.llm_base_model,
652+
wt_structure_input_ids=self.structure_input_ids,
639653
verbose=verbose,
640654
device=self.device).detach().cpu().numpy()
641655
y_llm_lora = self.llm_inference_function(
642-
x_llm,
643-
self.llm_model,
644-
self.input_ids,
645-
self.llm_attention_mask,
646-
self.structure_input_ids,
656+
xs=x_llm,
657+
wt_input_ids=self.wt_input_ids,
658+
attention_mask=self.llm_attention_mask,
659+
model=self.llm_model,
660+
wt_structure_input_ids=self.structure_input_ids,
647661
verbose=verbose,
648662
device=self.device).detach().cpu().numpy()
649663
elif self.llm_key == 'esm1v':
650664
x_llm_b = torch.from_numpy(get_batches(x_llm, batch_size=1, dtype=int))
651665
y_llm = self.llm_inference_function(
652-
x_llm_b,
653-
self.llm_attention_mask,
654-
self.llm_base_model,
666+
xs=x_llm,
667+
wt_input_ids=self.wt_input_ids,
668+
attention_mask=self.llm_attention_mask,
669+
model=self.llm_base_model,
655670
verbose=verbose,
656-
device=self.device).detach().cpu().numpy()
671+
device=self.device
672+
).detach().cpu().numpy()
657673
y_llm_lora = self.llm_inference_function(
658-
x_llm_b,
659-
self.llm_attention_mask,
660-
self.llm_model,
674+
xs=x_llm,
675+
wt_input_ids=self.wt_input_ids,
676+
attention_mask=self.llm_attention_mask,
677+
model=self.llm_model,
661678
verbose=verbose,
662-
device=self.device).detach().cpu().numpy()
679+
device=self.device
680+
).detach().cpu().numpy()
663681
if np.any(np.isnan(y_llm)) or np.any(np.isnan(y_llm_lora)):
664682
logger.warning(
665683
f"LLM predictions contains NaN's... replacing NaN's with "

pypef/plm/esm_lora_tune.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,6 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
4545
return base_model, lora_model, tokenizer, optimizer
4646

4747

48-
def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
49-
tokenized_sequences = []
50-
for seq in tqdm(sequences, desc='Tokenizing sequences', disable=not verbose):
51-
encoded_sequence, attention_mask = tokenizer(
52-
seq,
53-
padding='max_length',
54-
truncation=True, # False for not uniform length distribution (truncation)
55-
max_length=max_length
56-
).values()
57-
tokenized_sequences.append(encoded_sequence)
58-
return tokenized_sequences, attention_mask
59-
6048

6149
def get_y_pred_scores(encoded_sequences, attention_masks,
6250
model, device: str | None = None):

pypef/plm/inference.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from tqdm import tqdm
1414
from Bio import SeqIO
1515

16+
from pypef.plm.prosst_lora_tune import get_prosst_models, get_structure_quantizied
1617
from pypef.utils.helpers import get_device
1718
from pypef.plm.utils import corr_loss, get_batches
18-
from pypef.plm.esm_lora_tune import get_esm_models, tokenize_sequences
19+
from pypef.plm.esm_lora_tune import get_esm_models
1920

2021

2122
import logging
@@ -55,6 +56,7 @@ def unmasked_wt_score(
5556
verbose: bool = False,
5657
**model_kwargs
5758
):
59+
#print('unmasked_wt_score() tokenized_sequences.shape', tokenized_sequences.shape)
5860
if device is None:
5961
device = get_device()
6062
if wt_input_ids.dim() == 1:
@@ -322,17 +324,21 @@ def plm_inference(
322324

323325
scores = []
324326
if batch_size is None:
325-
xs_b = xs
327+
xs_b = torch.atleast_2d(xs)
326328
else:
327-
xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
329+
logger.info(f"Splitting tokenized sequences into batches...")
330+
xs_b = torch.from_numpy(get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True))
328331
desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
332+
#print(desc, "xs_b.shape", xs_b.shape)
329333

330334
kwargs = {}
331335
if mask_token_id is not None:
332336
kwargs["mask_token_id"] = mask_token_id
333337

334338
if wt_structure_input_ids is not None:
335339
kwargs["ss_input_ids"] = wt_structure_input_ids.to(device)
340+
341+
#print('xs_b.shape', xs_b.shape, 'xs_b[0]', xs_b[0])
336342

337343
pbar = tqdm(
338344
range(len(xs_b)),
@@ -342,7 +348,7 @@ def plm_inference(
342348

343349
for i in pbar:
344350
pll = inference_function(
345-
tokenized_sequences=torch.tensor(xs_b[i]),
351+
tokenized_sequences=xs_b[i],
346352
wt_input_ids=wt_input_ids,
347353
attention_mask=attention_mask,
348354
model=model,
@@ -361,7 +367,7 @@ def plm_train(
361367
loss_fn,
362368
model,
363369
optimizer,
364-
input_ids,
370+
wt_input_ids,
365371
attention_mask,
366372
batch_size: int = 5,
367373
n_epochs=50,
@@ -382,14 +388,19 @@ def plm_train(
382388
torch.manual_seed(seed)
383389
if device is None:
384390
device = get_device()
385-
logger.info(f"ProSST training using {device.upper()} device "
386-
f"(N_Train={len(torch.flatten(score_batches))})...")
387-
x_sequences_batched = get_batches(x_sequences, dtype=int, batch_size=batch_size,
388-
keep_remaining=False, verbose=True)
391+
print(f"Model training using {device.upper()} device "
392+
f"(N_Train={len(scores)})...")
393+
scores_batched = torch.from_numpy(
394+
get_batches(scores, dtype=float, batch_size=batch_size,
395+
keep_remaining=False, verbose=True)
396+
)
397+
x_sequences_batched = torch.from_numpy(
398+
get_batches(x_sequences, dtype=int, batch_size=batch_size,
399+
keep_remaining=False, verbose=True)
400+
)
389401
x_sequences_batched = x_sequences_batched.to(device)
390-
score_batches = get_batches(scores, dtype=float, batch_size=batch_size,
391-
keep_remaining=False, verbose=True)
392-
score_batches = score_batches.to(device)
402+
#print('x_sequences_batched.shape:', x_sequences_batched.shape)
403+
scores_batched = scores_batched.to(device)
393404
pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
394405
epoch_spearman_1 = -1.0
395406
did_not_improve_counter = 0
@@ -404,16 +415,20 @@ def plm_train(
404415
model.train()
405416
y_preds_detached = []
406417
pbar_batches = tqdm(
407-
zip(x_sequences_batched, score_batches),
418+
zip(x_sequences_batched, scores_batched),
408419
total=len(x_sequences), leave=False, disable=not verbose
409420
)
410421
for batch, (seqs_b, scores_b) in enumerate(pbar_batches):
411422
if abort_cb and abort_cb():
412423
return
424+
if seqs_b.dim() == 2:
425+
seqs_b = seqs_b.unsqueeze(0) # e.g., (5, 400) -> (1, 5 400)
413426
y_preds_b = plm_inference(
414-
seqs_b, model, input_ids, attention_mask,
415-
train=True, verbose=False
427+
xs=seqs_b,
428+
wt_input_ids=wt_input_ids, attention_mask=attention_mask,
429+
model=model, train=True, batch_size=None, verbose=False
416430
)
431+
#print('y_preds_b.shape', y_preds_b.shape, y_preds_b)
417432
y_preds_detached.append(y_preds_b.detach().cpu().numpy().flatten())
418433
loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
419434
if progress_cb:
@@ -428,12 +443,12 @@ def plm_train(
428443
f"sequence: {(batch + 1) * len(seqs_b):>5d}/{len(x_sequences) * len(seqs_b)}] "
429444
f"({device.upper()})"
430445
)
431-
epoch_spearman_2 = spearmanr(score_batches.cpu().numpy().flatten(),
446+
epoch_spearman_2 = spearmanr(scores_batched.cpu().numpy().flatten(),
432447
np.array(y_preds_detached).flatten())[0]
433448
if epoch_spearman_2 == np.nan:
434449
raise SystemError(
435450
f"No correlation between Y_true and Y_pred could be computed...\n"
436-
f"Y_true: {score_batches.cpu().numpy().flatten()}, "
451+
f"Y_true: {scores_batched.cpu().numpy().flatten()}, "
437452
f"Y_pred: {np.array(y_preds_detached)}"
438453
)
439454
if epoch_spearman_2 > epoch_spearman_1 or epoch == 0:
@@ -444,7 +459,7 @@ def plm_train(
444459
best_model_epoch = epoch
445460
best_model_perf = epoch_spearman_2
446461
best_model = (
447-
f"model_saves/Epoch{epoch}-Ntrain{len(score_batches.cpu().numpy().flatten())}"
462+
f"model_saves/Epoch{epoch}-Ntrain{len(scores_batched.cpu().numpy().flatten())}"
448463
f"-SpearCorr{epoch_spearman_2:.3f}.pt"
449464
)
450465
checkpoint(model, best_model)
@@ -456,7 +471,7 @@ def plm_train(
456471
logger.info(f'\nEarly stop at epoch {epoch}...')
457472
break
458473
loss_total = loss_fn(
459-
torch.flatten(score_batches).to('cpu'),
474+
torch.flatten(scores_batched).to('cpu'),
460475
torch.flatten(torch.Tensor(np.array(y_preds_detached).flatten()))
461476
)
462477
pbar_epochs.set_description(
@@ -586,6 +601,18 @@ def inference(
586601
return y_test_pred
587602

588603

604+
def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
605+
tokenized_sequences = []
606+
for seq in tqdm(sequences, desc='Tokenizing sequences', disable=not verbose):
607+
encoded_sequence, attention_mask = tokenizer(
608+
seq,
609+
padding='max_length',
610+
truncation=True, # False for not uniform length distribution (truncation)
611+
max_length=max_length
612+
).values()
613+
tokenized_sequences.append(encoded_sequence)
614+
return tokenized_sequences, attention_mask
615+
589616

590617
def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True):
591618
esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
@@ -655,7 +682,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
655682
'x_llm' : x_llm_train_prosst,
656683
'llm_attention_mask': prosst_attention_mask,
657684
'llm_vocab': prosst_vocab,
658-
'input_ids': input_ids,
685+
'wt_input_ids': input_ids,
659686
'structure_input_ids': structure_input_ids,
660687
'llm_tokenizer': prosst_tokenizer
661688
}

pypef/plm/prosst_lora_tune.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@
2020
from scipy.stats import spearmanr
2121
from tqdm import tqdm
2222
from peft import LoraConfig, get_peft_model
23-
from Bio import SeqIO, BiopythonParserWarning
23+
from Bio import BiopythonParserWarning
2424
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
2525

26-
from pypef.plm.utils import corr_loss
2726
from pypef.plm.prosst_structure.quantizer import PdbQuantizer
2827
from pypef.utils.helpers import get_device
29-
from pypef.plm.esm_lora_tune import tokenize_sequences
3028
from pypef.plm.utils import load_model_and_tokenizer
31-
from pypef.plm.inference import plm_inference
3229

3330

3431
def prosst_simple_vocab_aa_tokenizer(sequences, vocab, verbose=True):

0 commit comments

Comments
 (0)