1313from tqdm import tqdm
1414from Bio import SeqIO
1515
16+ from pypef .plm .prosst_lora_tune import get_prosst_models , get_structure_quantizied
1617from pypef .utils .helpers import get_device
1718from pypef .plm .utils import corr_loss , get_batches
18- from pypef .plm .esm_lora_tune import get_esm_models , tokenize_sequences
19+ from pypef .plm .esm_lora_tune import get_esm_models
1920
2021
2122import logging
@@ -55,6 +56,7 @@ def unmasked_wt_score(
5556 verbose : bool = False ,
5657 ** model_kwargs
5758 ):
59+ #print('unmasked_wt_score() tokenized_sequences.shape', tokenized_sequences.shape)
5860 if device is None :
5961 device = get_device ()
6062 if wt_input_ids .dim () == 1 :
@@ -322,17 +324,21 @@ def plm_inference(
322324
323325 scores = []
324326 if batch_size is None :
325- xs_b = xs
327+ xs_b = torch . atleast_2d ( xs )
326328 else :
327- xs_b = get_batches (xs , dtype = int , batch_size = batch_size , keep_remaining = True , verbose = True )
329+ logger .info (f"Splitting tokenized sequences into batches..." )
330+ xs_b = torch .from_numpy (get_batches (xs , dtype = int , batch_size = batch_size , keep_remaining = True , verbose = True ))
328331 desc = f"Inference: { inference_type } batch (size={ batch_size } ) processing ({ device .upper ()} )'"
332+ #print(desc, "xs_b.shape", xs_b.shape)
329333
330334 kwargs = {}
331335 if mask_token_id is not None :
332336 kwargs ["mask_token_id" ] = mask_token_id
333337
334338 if wt_structure_input_ids is not None :
335339 kwargs ["ss_input_ids" ] = wt_structure_input_ids .to (device )
340+
341+ #print('xs_b.shape', xs_b.shape, 'xs_b[0]', xs_b[0])
336342
337343 pbar = tqdm (
338344 range (len (xs_b )),
@@ -342,7 +348,7 @@ def plm_inference(
342348
343349 for i in pbar :
344350 pll = inference_function (
345- tokenized_sequences = torch . tensor ( xs_b [i ]) ,
351+ tokenized_sequences = xs_b [i ],
346352 wt_input_ids = wt_input_ids ,
347353 attention_mask = attention_mask ,
348354 model = model ,
@@ -361,7 +367,7 @@ def plm_train(
361367 loss_fn ,
362368 model ,
363369 optimizer ,
364- input_ids ,
370+ wt_input_ids ,
365371 attention_mask ,
366372 batch_size : int = 5 ,
367373 n_epochs = 50 ,
@@ -382,14 +388,19 @@ def plm_train(
382388 torch .manual_seed (seed )
383389 if device is None :
384390 device = get_device ()
385- logger .info (f"ProSST training using { device .upper ()} device "
386- f"(N_Train={ len (torch .flatten (score_batches ))} )..." )
387- x_sequences_batched = get_batches (x_sequences , dtype = int , batch_size = batch_size ,
388- keep_remaining = False , verbose = True )
391+ print (f"Model training using { device .upper ()} device "
392+ f"(N_Train={ len (scores )} )..." )
393+ scores_batched = torch .from_numpy (
394+ get_batches (scores , dtype = float , batch_size = batch_size ,
395+ keep_remaining = False , verbose = True )
396+ )
397+ x_sequences_batched = torch .from_numpy (
398+ get_batches (x_sequences , dtype = int , batch_size = batch_size ,
399+ keep_remaining = False , verbose = True )
400+ )
389401 x_sequences_batched = x_sequences_batched .to (device )
390- score_batches = get_batches (scores , dtype = float , batch_size = batch_size ,
391- keep_remaining = False , verbose = True )
392- score_batches = score_batches .to (device )
402+ #print('x_sequences_batched.shape:', x_sequences_batched.shape)
403+ scores_batched = scores_batched .to (device )
393404 pbar_epochs = tqdm (range (1 , n_epochs + 1 ), disable = not verbose )
394405 epoch_spearman_1 = - 1.0
395406 did_not_improve_counter = 0
@@ -404,16 +415,20 @@ def plm_train(
404415 model .train ()
405416 y_preds_detached = []
406417 pbar_batches = tqdm (
407- zip (x_sequences_batched , score_batches ),
418+ zip (x_sequences_batched , scores_batched ),
408419 total = len (x_sequences ), leave = False , disable = not verbose
409420 )
410421 for batch , (seqs_b , scores_b ) in enumerate (pbar_batches ):
411422 if abort_cb and abort_cb ():
412423 return
424+ if seqs_b .dim () == 2 :
425+ seqs_b = seqs_b .unsqueeze (0 ) # e.g., (5, 400) -> (1, 5 400)
413426 y_preds_b = plm_inference (
414- seqs_b , model , input_ids , attention_mask ,
415- train = True , verbose = False
427+ xs = seqs_b ,
428+ wt_input_ids = wt_input_ids , attention_mask = attention_mask ,
429+ model = model , train = True , batch_size = None , verbose = False
416430 )
431+ #print('y_preds_b.shape', y_preds_b.shape, y_preds_b)
417432 y_preds_detached .append (y_preds_b .detach ().cpu ().numpy ().flatten ())
418433 loss = loss_fn (scores_b , y_preds_b ) / n_batch_grad_accumulations
419434 if progress_cb :
@@ -428,12 +443,12 @@ def plm_train(
428443 f"sequence: { (batch + 1 ) * len (seqs_b ):>5d} /{ len (x_sequences ) * len (seqs_b )} ] "
429444 f"({ device .upper ()} )"
430445 )
431- epoch_spearman_2 = spearmanr (score_batches .cpu ().numpy ().flatten (),
446+ epoch_spearman_2 = spearmanr (scores_batched .cpu ().numpy ().flatten (),
432447 np .array (y_preds_detached ).flatten ())[0 ]
433448 if epoch_spearman_2 == np .nan :
434449 raise SystemError (
435450 f"No correlation between Y_true and Y_pred could be computed...\n "
436- f"Y_true: { score_batches .cpu ().numpy ().flatten ()} , "
451+ f"Y_true: { scores_batched .cpu ().numpy ().flatten ()} , "
437452 f"Y_pred: { np .array (y_preds_detached )} "
438453 )
439454 if epoch_spearman_2 > epoch_spearman_1 or epoch == 0 :
@@ -444,7 +459,7 @@ def plm_train(
444459 best_model_epoch = epoch
445460 best_model_perf = epoch_spearman_2
446461 best_model = (
447- f"model_saves/Epoch{ epoch } -Ntrain{ len (score_batches .cpu ().numpy ().flatten ())} "
462+ f"model_saves/Epoch{ epoch } -Ntrain{ len (scores_batched .cpu ().numpy ().flatten ())} "
448463 f"-SpearCorr{ epoch_spearman_2 :.3f} .pt"
449464 )
450465 checkpoint (model , best_model )
@@ -456,7 +471,7 @@ def plm_train(
456471 logger .info (f'\n Early stop at epoch { epoch } ...' )
457472 break
458473 loss_total = loss_fn (
459- torch .flatten (score_batches ).to ('cpu' ),
474+ torch .flatten (scores_batched ).to ('cpu' ),
460475 torch .flatten (torch .Tensor (np .array (y_preds_detached ).flatten ()))
461476 )
462477 pbar_epochs .set_description (
@@ -586,6 +601,18 @@ def inference(
586601 return y_test_pred
587602
588603
def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
    """Tokenize sequences to fixed-length token-id lists.

    Each sequence is passed through ``tokenizer`` with padding to
    ``max_length`` and truncation enabled, so every returned id list has
    uniform length.

    Parameters
    ----------
    sequences : iterable of str
        Sequences to tokenize.
    tokenizer : callable
        HF-style tokenizer call; assumed to return a mapping whose first
        two values are ``(input_ids, attention_mask)`` in that order —
        TODO confirm this ordering holds for the tokenizers used here.
    max_length : int
        Pad/truncate target length.
    verbose : bool, default True
        Show a tqdm progress bar while tokenizing.

    Returns
    -------
    tuple
        ``(tokenized_sequences, attention_mask)`` where the mask is the
        one from the LAST sequence (presumably identical for all inputs
        when padding to a fixed ``max_length`` — verify against caller),
        or ``None`` when ``sequences`` is empty.
    """
    tokenized_sequences = []
    # Bug fix: previously `attention_mask` was unbound for an empty
    # `sequences`, raising NameError at the return statement.
    attention_mask = None
    # Wrap in tqdm only when a progress bar is wanted; identical in effect
    # to tqdm(..., disable=not verbose), which iterates transparently.
    iterable = (
        tqdm(sequences, desc='Tokenizing sequences') if verbose else sequences
    )
    for seq in iterable:
        encoded_sequence, attention_mask = tokenizer(
            seq,
            padding='max_length',
            truncation=True,  # False would keep non-uniform lengths (no truncation)
            max_length=max_length
        ).values()
        tokenized_sequences.append(encoded_sequence)
    return tokenized_sequences, attention_mask
615+
589616
590617def esm_setup (wt_seq , sequences , device : str | None = None , verbose : bool = True ):
591618 esm_base_model , esm_lora_model , esm_tokenizer , esm_optimizer = get_esm_models ()
@@ -655,7 +682,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
655682 'x_llm' : x_llm_train_prosst ,
656683 'llm_attention_mask' : prosst_attention_mask ,
657684 'llm_vocab' : prosst_vocab ,
658- 'input_ids ' : input_ids ,
685+ 'wt_input_ids ' : input_ids ,
659686 'structure_input_ids' : structure_input_ids ,
660687 'llm_tokenizer' : prosst_tokenizer
661688 }
0 commit comments