# Some helper functions for inference of different models
55# based on simple/wrapping functions
66
7+ import os
8+ import inspect
79import numpy as np
10+ from scipy .stats import spearmanr
811import torch
912import torch .nn .functional as F
1013from tqdm import tqdm
14+ from Bio import SeqIO
1115
1216from pypef .utils .helpers import get_device
1317from pypef .plm .utils import corr_loss , get_batches
1418from pypef .plm .esm_lora_tune import get_esm_models , tokenize_sequences
1519
20+
1621import logging
1722logger = logging .getLogger ('pypef.llm.inference' )
1823
1924
def checkpoint(model, filename):
    """Serialize the current parameters of *model* (its state dict) to *filename*."""
    state_dict = model.state_dict()
    torch.save(state_dict, filename)
27+
28+
def load_model(model, filename):
    """Restore parameters saved by `checkpoint()` into *model* (in place)."""
    abs_path = os.path.abspath(filename)
    logger.info(f'Loading best model: {abs_path}...')
    # weights_only=True restricts unpickling to tensors (safe-load)
    state_dict = torch.load(filename, weights_only=True)
    model.load_state_dict(state_dict)
32+
33+
2034def tokenize_sequences (sequences , tokenizer , max_length , verbose = True ):
2135 tokenized_sequences = []
2236 for seq in tqdm (sequences , desc = 'Tokenizing sequences' , disable = not verbose ):
@@ -45,25 +59,34 @@ def unmasked_wt_score(
4559 device = get_device ()
4660 if wt_input_ids .dim () == 1 :
4761 wt_input_ids = wt_input_ids .unsqueeze (0 )
62+ wt_input_ids = wt_input_ids .to (device )
4863 #structure_input_ids = model_kwargs.get("structure_input_ids", None)
4964
5065 attention_masks = torch .Tensor (np .full (
51- shape = np .shape (wt_input_ids ), fill_value = attention_mask )).to (torch .int64 )
52- if train :
53- outputs = model (
54- input_ids = wt_input_ids .to (device ),
55- attention_mask = attention_masks .to (device ),
56- ** model_kwargs
57- )
58-
59- else :
60- with torch .no_grad ():
66+ shape = np .shape (wt_input_ids ), fill_value = attention_mask )).to (torch .int64 ).to (device )
67+ try :
68+ if train :
6169 outputs = model (
62- input_ids = wt_input_ids . to ( device ) ,
63- attention_mask = attention_masks . to ( device ) ,
70+ input_ids = wt_input_ids ,
71+ attention_mask = attention_masks ,
6472 ** model_kwargs
6573 )
6674
75+ else :
76+ with torch .no_grad ():
77+ outputs = model (
78+ input_ids = wt_input_ids ,
79+ attention_mask = attention_masks ,
80+ ** model_kwargs
81+ )
82+ except TypeError as e :
83+ print (f"Did not find model input keyword arguments (kwargs: "
84+ f"{ model_kwargs .keys ()} ). Available kawrgs identified from "
85+ f"model.forward function inspect:\n "
86+ f"{ inspect .signature (model .forward )} \n Original error:" )
87+ raise e
88+
89+
6790 logits = outputs .logits
6891 logits = logits .squeeze (0 ) # remove batch dim
6992 # Better make sure that special tokens are always removed / masked
@@ -174,10 +197,8 @@ def mutation_only_mutation_masked_pll(
174197 output_hidden_states = False
175198 )
176199 logits = outputs .logits # (1, L, V)
177-
178200 log_probs = F .log_softmax (logits [0 , pos ], dim = - 1 )
179201 true_token = tokenized_seq [pos ]
180-
181202 pll = pll + log_probs [true_token ]
182203
183204 plls [i ] = pll
@@ -186,8 +207,8 @@ def mutation_only_mutation_masked_pll(
186207
187208
188209def mutation_all_pos_masked_pll (
189- tokenized_sequences : torch .Tensor , # (L,)
190- attention_mask : torch .Tensor , # (L,)
210+ tokenized_sequences : torch .Tensor , # (L,)
211+ attention_mask : torch .Tensor , # (L,)
191212 model ,
192213 mask_token_id : int ,
193214 train : bool = False ,
@@ -275,7 +296,7 @@ def plm_inference(
275296 mask_token_id = None ,
276297 inference_type = 'unmasked' ,
277298 wt_structure_input_ids = None ,
278- batch_size = 5 ,
299+ batch_size : int | None = 5 ,
279300 train = False ,
280301 device = None ,
281302 verbose = False ,
@@ -300,16 +321,18 @@ def plm_inference(
300321 raise SystemError ("Choose between 'mutation-masking', 'unmasked', and 'full-masking'" )
301322
302323 scores = []
303-
304- xs_b = get_batches (xs , dtype = int , batch_size = batch_size , keep_remaining = True , verbose = True )
324+ if batch_size is None :
325+ xs_b = xs
326+ else :
327+ xs_b = get_batches (xs , dtype = int , batch_size = batch_size , keep_remaining = True , verbose = True )
305328 desc = f"Inference: { inference_type } batch (size={ batch_size } ) processing ({ device .upper ()} )'"
306329
307330 kwargs = {}
308331 if mask_token_id is not None :
309332 kwargs ["mask_token_id" ] = mask_token_id
310333
311334 if wt_structure_input_ids is not None :
312- kwargs ["structure_input_ids " ] = wt_structure_input_ids
335+ kwargs ["ss_input_ids " ] = wt_structure_input_ids . to ( device )
313336
314337 pbar = tqdm (
315338 range (len (xs_b )),
@@ -332,6 +355,141 @@ def plm_inference(
332355 return torch .cat (scores )
333356
334357
def plm_train(
        x_sequences,
        scores,
        loss_fn,
        model,
        optimizer,
        input_ids,
        attention_mask,
        batch_size: int = 5,
        n_epochs=50,
        device: str | None = None,
        seed: int | None = None,
        early_stop: int = 50,
        verbose: bool = True,
        wt_structure_input_ids=None,
        n_batch_grad_accumulations: int = 1,
        raise_error_on_train_fail: bool = True,
        progress_cb=None,
        abort_cb=None
):
    """
    Wrapper around `plm_inference()` for fine-tuning a PLM on
    sequence-fitness data.

    Trains for up to ``n_epochs`` epochs with gradient accumulation over
    ``n_batch_grad_accumulations`` batches, checkpoints the model to
    ``model_saves/`` whenever the epoch-wise Spearman correlation improves,
    early-stops after ``early_stop`` epochs without improvement, and finally
    reloads the best checkpoint into ``model`` (mutated in place; the
    function returns ``None``).

    Parameters
    ----------
    x_sequences, scores : encoded variant sequences and target fitness values.
    loss_fn, model, optimizer : training objective, PLM (e.g. LoRA-wrapped),
        and its optimizer.
    input_ids, attention_mask : tokenized wild-type inputs forwarded to
        `plm_inference()`.
    wt_structure_input_ids : optional structure tokens, forwarded to
        `plm_inference()` when given.
    progress_cb, abort_cb : optional callbacks for UI progress reporting and
        cooperative cancellation (``abort_cb() -> True`` stops training).

    Raises
    ------
    SystemError
        If the epoch Spearman correlation is NaN.
    RuntimeError
        If no improved model was ever saved and
        ``raise_error_on_train_fail`` is True.
    """
    if seed is not None:
        torch.manual_seed(seed)
    if device is None:
        device = get_device()
    x_sequences_batched = get_batches(x_sequences, dtype=int, batch_size=batch_size,
                                      keep_remaining=False, verbose=True)
    x_sequences_batched = x_sequences_batched.to(device)
    score_batches = get_batches(scores, dtype=float, batch_size=batch_size,
                                keep_remaining=False, verbose=True)
    score_batches = score_batches.to(device)
    # Log only after score_batches exists (it was previously referenced
    # before assignment, raising NameError on every call).
    logger.info(f"PLM training using {device.upper()} device "
                f"(N_Train={len(torch.flatten(score_batches))})...")
    n_batches = len(x_sequences_batched)
    pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
    epoch_spearman_1 = -1.0
    did_not_improve_counter = 0
    best_model = None
    best_model_epoch = np.nan
    best_model_perf = np.nan
    loss = np.nan
    os.makedirs('model_saves', exist_ok=True)
    for epoch in pbar_epochs:
        if epoch == 1:  # epochs are 1-based; the former 'epoch == 0' never fired
            pbar_epochs.set_description(f'Epoch {epoch}/{n_epochs}')
        model.train()
        y_preds_detached = []
        pbar_batches = tqdm(
            zip(x_sequences_batched, score_batches),
            total=n_batches, leave=False, disable=not verbose
        )
        for batch, (seqs_b, scores_b) in enumerate(pbar_batches):
            if abort_cb and abort_cb():
                return
            # Forward wt_structure_input_ids (was accepted but silently dropped).
            y_preds_b = plm_inference(
                seqs_b, model, input_ids, attention_mask,
                wt_structure_input_ids=wt_structure_input_ids,
                train=True, verbose=False
            )
            y_preds_detached.append(y_preds_b.detach().cpu().numpy().flatten())
            # Scale the loss so accumulated gradients match full-batch training.
            loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
            if progress_cb:
                progress_cb(epoch - 1, batch + 1, len(pbar_epochs), n_batches, loss)
            loss.backward()
            # Step on every n-th batch and always on the last (possibly short) one.
            if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == n_batches:
                optimizer.step()
                optimizer.zero_grad()
            # Progress counts batches out of the batch total (not the raw
            # sequence count, which over-stated the denominator before).
            pbar_batches.set_description(
                f"Epoch: {epoch}. Loss: {loss.detach():>1f} "
                f"[batch: {batch + 1}/{n_batches} | "
                f"sequence: {(batch + 1) * len(seqs_b):>5d}/{n_batches * len(seqs_b)}] "
                f"({device.upper()})"
            )
        epoch_spearman_2 = spearmanr(score_batches.cpu().numpy().flatten(),
                                     np.array(y_preds_detached).flatten())[0]
        # '== np.nan' is always False; np.isnan() actually detects the failure.
        if np.isnan(epoch_spearman_2):
            raise SystemError(
                f"No correlation between Y_true and Y_pred could be computed...\n"
                f"Y_true: {score_batches.cpu().numpy().flatten()}, "
                f"Y_pred: {np.array(y_preds_detached)}"
            )
        if epoch_spearman_2 > epoch_spearman_1 or epoch == 1:
            # Keep only the single best checkpoint on disk.
            if best_model is not None and os.path.isfile(best_model):
                os.remove(best_model)
            did_not_improve_counter = 0
            best_model_epoch = epoch
            best_model_perf = epoch_spearman_2
            best_model = (
                f"model_saves/Epoch{epoch}-Ntrain{len(score_batches.cpu().numpy().flatten())}"
                f"-SpearCorr{epoch_spearman_2:.3f}.pt"
            )
            checkpoint(model, best_model)
            epoch_spearman_1 = epoch_spearman_2
        else:
            did_not_improve_counter += 1
            if did_not_improve_counter >= early_stop:
                logger.info(f'\nEarly stop at epoch {epoch}...')
                break
        loss_total = loss_fn(
            torch.flatten(score_batches).to('cpu'),
            torch.flatten(torch.Tensor(np.array(y_preds_detached).flatten()))
        )
        pbar_epochs.set_description(
            f'Epoch {epoch}/{n_epochs} [SpearCorr: {epoch_spearman_2:.3f}, Loss: {loss_total:.3f}] '
            f'(Best epoch: {best_model_epoch}: {best_model_perf:.3f})')
        if progress_cb:
            progress_cb(epoch, batch + 1, len(pbar_epochs), n_batches, loss)
    if best_model is None:
        msg = ("Failed to train a model (probably due to the input "
               "data characteristics and loss/correlation being NaN).")
        if raise_error_on_train_fail:
            raise RuntimeError(msg)
        logger.warning(f"{msg} Continuing nonetheless (using failed model "
                       f"and replacing NaN's with zeros)...")
    else:
        logger.info(f"Loading best model as {best_model}...")
        load_model(model, best_model)
    return
488+
489+
490+
491+
492+
335493######################### Deprecated
336494
337495def llm_tokenizer (llm_dict , seqs , verbose = True ):
@@ -444,12 +602,12 @@ def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True
444602 'llm_base_model' : esm_base_model ,
445603 'llm_model' : esm_lora_model ,
446604 'llm_optimizer' : esm_optimizer ,
447- # 'llm_train_function': esm_train ,
605+ 'llm_train_function' : plm_train ,
448606 'llm_inference_function' : plm_inference ,
449607 'llm_loss_function' : corr_loss ,
450- 'x_llm' : x_esm ,
451- 'input_ids ' : wt_tokens ,
452- 'llm_attention_mask ' : esm_attention_mask ,
608+ 'x_llm' : torch . tensor ( x_esm ) ,
609+ 'llm_attention_mask ' : torch . tensor ( esm_attention_mask ) ,
610+ 'wt_input_ids ' : torch . tensor ( wt_tokens ) ,
453611 'llm_tokenizer' : esm_tokenizer
454612 }
455613 }
@@ -491,7 +649,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
491649 'llm_base_model' : prosst_base_model ,
492650 'llm_model' : prosst_lora_model ,
493651 'llm_optimizer' : prosst_optimizer ,
494- # 'llm_train_function': prosst_train ,
652+ 'llm_train_function' : plm_train ,
495653 'llm_inference_function' : plm_inference , # prosst_infer,
496654 'llm_loss_function' : corr_loss ,
497655 'x_llm' : x_llm_train_prosst ,
0 commit comments