10 | 10 | from tqdm import tqdm |
11 | 11 |
12 | 12 | from pypef.utils.helpers import get_device |
13 | | -from pypef.plm.utils import get_batches |
14 | | -from pypef.plm.esm_lora_tune import esm_infer, esm_setup, tokenize_sequences |
15 | | -from pypef.plm.prosst_lora_tune import prosst_setup, prosst_simple_vocab_aa_tokenizer, prosst_infer |
| 13 | +from pypef.plm.utils import corr_loss, get_batches |
| 14 | +from pypef.plm.esm_lora_tune import get_esm_models, tokenize_sequences |
16 | 15 |
17 | 16 | import logging |
18 | 17 | logger = logging.getLogger('pypef.llm.inference') |
@@ -427,3 +426,80 @@ def inference( |
427 | 426 |     else: |
428 | 427 |         raise RuntimeError("Unknown LLM option.") |
429 | 428 |     return y_test_pred |
| 429 | + |
| 430 | + |
| 431 | + |
| 432 | +def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True): |
| 433 | +    esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models() |
| 434 | +    esm_base_model = esm_base_model.to(device) |
| 435 | +    wt_tokens, _ = tokenize_sequences( |
| 436 | +        [wt_seq], |
| 437 | +        esm_tokenizer, |
| 438 | +        max_length=len(wt_seq) + 2 |
| 439 | +    ) |
| 440 | +    x_esm, esm_attention_mask = tokenize_sequences( |
| 441 | +        sequences, esm_tokenizer, max_length=len(wt_seq) + 2, verbose=verbose) |
| 442 | +    llm_dict_esm = { |
| 443 | +        'esm1v': { |
| 444 | +            'llm_base_model': esm_base_model, |
| 445 | +            'llm_model': esm_lora_model, |
| 446 | +            'llm_optimizer': esm_optimizer, |
| 447 | +            # 'llm_train_function': esm_train, |
| 448 | +            'llm_inference_function': plm_inference, |
| 449 | +            'llm_loss_function': corr_loss, |
| 450 | +            'x_llm': x_esm, |
| 451 | +            'input_ids': wt_tokens, |
| 452 | +            'llm_attention_mask': esm_attention_mask, |
| 453 | +            'llm_tokenizer': esm_tokenizer |
| 454 | +        } |
| 455 | +    } |
| 456 | +    return llm_dict_esm |
| 457 | + |
| 458 | + |
| 459 | +def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose: bool = True): |
| 460 | +    if wt_seq is None: |
| 461 | +        raise SystemError( |
| 462 | +            "Running ProSST requires a wild-type sequence " |
| 463 | +            "FASTA file input for embedding sequences! " |
| 464 | +            "Specify a FASTA file with the --wt flag." |
| 465 | +        ) |
| 466 | +    if pdb_file is None: |
| 467 | +        raise SystemError( |
| 468 | +            "Running ProSST requires a PDB file input " |
| 469 | +            "for embedding sequences! Specify a PDB file " |
| 470 | +            "with the --pdb flag." |
| 471 | +        ) |
| 472 | + |
| 473 | +    pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq) |
| 474 | +    assert wt_seq == pdb_seq, ( |
| 475 | +        "Wild-type sequence does not match the PDB-extracted sequence:" |
| 476 | +        f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}" |
| 477 | +    ) |
| 478 | +    prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models() |
| 479 | +    prosst_vocab = prosst_tokenizer.get_vocab() |
| 480 | +    prosst_base_model = prosst_base_model.to(device) |
| 481 | +    prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)  # replaces the optimizer returned by get_prosst_models() |
| 482 | +    input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied( |
| 483 | +        pdb_file, prosst_tokenizer, wt_seq, verbose=verbose |
| 484 | +    ) |
| 485 | +    x_llm_train_prosst, _attention_mask = tokenize_sequences( |
| 486 | +        sequences=sequences, tokenizer=prosst_tokenizer, |
| 487 | +        max_length=len(wt_seq) + 2, verbose=verbose |
| 488 | +    ) |
| 489 | +    llm_dict_prosst = { |
| 490 | +        'prosst': { |
| 491 | +            'llm_base_model': prosst_base_model, |
| 492 | +            'llm_model': prosst_lora_model, |
| 493 | +            'llm_optimizer': prosst_optimizer, |
| 494 | +            # 'llm_train_function': prosst_train, |
| 495 | +            'llm_inference_function': plm_inference,  # prosst_infer, |
| 496 | +            'llm_loss_function': corr_loss, |
| 497 | +            'x_llm': x_llm_train_prosst, |
| 498 | +            'llm_attention_mask': prosst_attention_mask, |
| 499 | +            'llm_vocab': prosst_vocab, |
| 500 | +            'input_ids': input_ids, |
| 501 | +            'structure_input_ids': structure_input_ids, |
| 502 | +            'llm_tokenizer': prosst_tokenizer |
| 503 | +        } |
| 504 | +    } |
| 505 | +    return llm_dict_prosst |
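
A minimal usage sketch for the two setup helpers added in this commit: it only builds the setup dictionaries and reads back keys that appear in the diff above. The wild-type sequence, variant sequences, and PDB path are placeholders, and the downstream plm_inference call is omitted since its signature is not part of this hunk.

from pypef.utils.helpers import get_device

device = get_device()
wt_seq = "MKTAYIAKQR"                        # placeholder wild-type sequence
variants = ["MATAYIAKQR", "MKTAYIAKQW"]      # placeholder variant sequences

# ESM-1v: needs only the wild-type sequence and the variant sequences.
llm_dict_esm = esm_setup(wt_seq, variants, device=device)
esm_entry = llm_dict_esm['esm1v']
x_esm = esm_entry['x_llm']                   # tokenized variant sequences
esm_mask = esm_entry['llm_attention_mask']   # attention mask matching x_esm

# ProSST: additionally needs a PDB file whose chain sequence equals wt_seq.
llm_dict_prosst = prosst_setup(wt_seq, "wild_type.pdb", variants, device=device)
prosst_entry = llm_dict_prosst['prosst']
x_prosst = prosst_entry['x_llm']                      # tokenized variant sequences
structure_ids = prosst_entry['structure_input_ids']   # quantized structure tokens from the PDB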