
Commit fd3e25b

dev/fail: further test implementation of plm_inference()
1 parent a7d2040 commit fd3e25b

6 files changed (+129, -91 lines changed)


pypef/hybrid/hybrid_model.py

Lines changed: 5 additions & 5 deletions
@@ -37,9 +37,9 @@
 from pypef.utils.plot import plot_y_true_vs_y_pred
 import pypef.dca.gremlin_inference
 from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
-from pypef.plm.esm_lora_tune import esm_setup, get_esm_models
-from pypef.plm.prosst_lora_tune import get_prosst_models, prosst_setup
-from pypef.plm.inference import llm_tokenizer, inference
+from pypef.plm.esm_lora_tune import get_esm_models
+from pypef.plm.prosst_lora_tune import get_prosst_models
+from pypef.plm.inference import esm_setup, llm_tokenizer, inference
 from pypef.plm.utils import get_batches

 # sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
@@ -84,7 +84,7 @@ def __init__(
         self.llm_base_model = llm_model_input['esm1v']['llm_base_model']
         self.llm_model = llm_model_input['esm1v']['llm_model']
         self.llm_optimizer = llm_model_input['esm1v']['llm_optimizer']
-        self.llm_train_function = llm_model_input['esm1v']['llm_train_function']
+        #self.llm_train_function = llm_model_input['esm1v']['llm_train_function']
         self.llm_inference_function = llm_model_input['esm1v']['llm_inference_function']
         self.llm_loss_function = llm_model_input['esm1v']['llm_loss_function']
         self.x_train_llm = llm_model_input['esm1v']['x_llm']
@@ -94,7 +94,7 @@ def __init__(
         self.llm_base_model = llm_model_input['prosst']['llm_base_model']
         self.llm_model = llm_model_input['prosst']['llm_model']
         self.llm_optimizer = llm_model_input['prosst']['llm_optimizer']
-        self.llm_train_function = llm_model_input['prosst']['llm_train_function']
+        #self.llm_train_function = llm_model_input['prosst']['llm_train_function']
         self.llm_inference_function = llm_model_input['prosst']['llm_inference_function']
         self.llm_loss_function = llm_model_input['prosst']['llm_loss_function']
         self.x_train_llm = llm_model_input['prosst']['x_llm']
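
For orientation, the llm_model_input dict consumed in __init__ above is the structure returned by esm_setup()/prosst_setup(), which this commit moves into pypef/plm/inference.py (diff further down). A minimal sketch of the ESM-1v entry as the hybrid model reads it; the wild-type and variant sequences are placeholder values, not data from the repository:

    from pypef.plm.inference import esm_setup

    wt_seq = 'MSKGEELFTG'                     # placeholder wild-type sequence
    sequences = ['MSKGEALFTG', 'MSKGEELFTG']  # placeholder variant sequences

    llm_model_input = esm_setup(wt_seq, sequences, device='cpu')
    esm = llm_model_input['esm1v']
    # Keys read by DCALLMHybridModel.__init__ above; note that
    # 'llm_train_function' is commented out on both sides in this commit.
    base_model = esm['llm_base_model']
    inference_fn = esm['llm_inference_function']  # plm_inference
    x_llm = esm['x_llm']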

pypef/plm/esm_lora_tune.py

Lines changed: 1 addition & 25 deletions
@@ -191,28 +191,4 @@ def esm_train(
     model.train(False)


-def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True):
-    esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
-    esm_base_model = esm_base_model.to(device)
-    wt_tokens, _ = tokenize_sequences(
-        [wt_seq],
-        esm_tokenizer,
-        max_length=len(wt_seq) + 2
-    )
-    x_esm, esm_attention_mask = tokenize_sequences(
-        sequences, esm_tokenizer, max_length=len(wt_seq) + 2, verbose=verbose)
-    llm_dict_esm = {
-        'esm1v': {
-            'llm_base_model': esm_base_model,
-            'llm_model': esm_lora_model,
-            'llm_optimizer': esm_optimizer,
-            'llm_train_function': esm_train,
-            'llm_inference_function': esm_infer,
-            'llm_loss_function': corr_loss,
-            'x_llm' : x_esm,
-            'input_ids': wt_tokens,
-            'llm_attention_mask': esm_attention_mask,
-            'llm_tokenizer': esm_tokenizer
-        }
-    }
-    return llm_dict_esm
+

pypef/plm/inference.py

Lines changed: 79 additions & 3 deletions
@@ -10,9 +10,8 @@
 from tqdm import tqdm

 from pypef.utils.helpers import get_device
-from pypef.plm.utils import get_batches
-from pypef.plm.esm_lora_tune import esm_infer, esm_setup, tokenize_sequences
-from pypef.plm.prosst_lora_tune import prosst_setup, prosst_simple_vocab_aa_tokenizer, prosst_infer
+from pypef.plm.utils import corr_loss, get_batches
+from pypef.plm.esm_lora_tune import get_esm_models, tokenize_sequences

 import logging
 logger = logging.getLogger('pypef.llm.inference')
@@ -427,3 +426,80 @@ def inference(
     else:
         raise RuntimeError("Unknown LLM option.")
     return y_test_pred
+
+
+
+def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True):
+    esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
+    esm_base_model = esm_base_model.to(device)
+    wt_tokens, _ = tokenize_sequences(
+        [wt_seq],
+        esm_tokenizer,
+        max_length=len(wt_seq) + 2
+    )
+    x_esm, esm_attention_mask = tokenize_sequences(
+        sequences, esm_tokenizer, max_length=len(wt_seq) + 2, verbose=verbose)
+    llm_dict_esm = {
+        'esm1v': {
+            'llm_base_model': esm_base_model,
+            'llm_model': esm_lora_model,
+            'llm_optimizer': esm_optimizer,
+            #'llm_train_function': esm_train,
+            'llm_inference_function': plm_inference,
+            'llm_loss_function': corr_loss,
+            'x_llm' : x_esm,
+            'input_ids': wt_tokens,
+            'llm_attention_mask': esm_attention_mask,
+            'llm_tokenizer': esm_tokenizer
+        }
+    }
+    return llm_dict_esm
+
+
+def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose: bool = True):
+    if wt_seq is None:
+        raise SystemError(
+            "Running ProSST requires a wild-type sequence "
+            "FASTA file input for embedding sequences! "
+            "Specify a FASTA file with the --wt flag."
+        )
+    if pdb_file is None:
+        raise SystemError(
+            "Running ProSST requires a PDB file input "
+            "for embedding sequences! Specify a PDB file "
+            "with the --pdb flag."
+        )
+
+    pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq)
+    assert wt_seq == pdb_seq, (
+        f"Wild-type sequence is not matching PDB-extracted sequence:"
+        f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}"
+    )
+    prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
+    prosst_vocab = prosst_tokenizer.get_vocab()
+    prosst_base_model = prosst_base_model.to(device)
+    prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
+    input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
+        pdb_file, prosst_tokenizer, wt_seq, verbose=verbose
+    )
+    x_llm_train_prosst, _attention_mask = tokenize_sequences(
+        sequences=sequences, tokenizer=prosst_tokenizer,
+        max_length=len(wt_seq) + 2, verbose=verbose
+    )
+    llm_dict_prosst = {
+        'prosst': {
+            'llm_base_model': prosst_base_model,
+            'llm_model': prosst_lora_model,
+            'llm_optimizer': prosst_optimizer,
+            #'llm_train_function': prosst_train,
+            'llm_inference_function': plm_inference, # prosst_infer,
+            'llm_loss_function': corr_loss,
+            'x_llm' : x_llm_train_prosst,
+            'llm_attention_mask': prosst_attention_mask,
+            'llm_vocab': prosst_vocab,
+            'input_ids': input_ids,
+            'structure_input_ids': structure_input_ids,
+            'llm_tokenizer': prosst_tokenizer
+        }
+    }
+    return llm_dict_prosst
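
The new esm_setup()/prosst_setup() both register plm_inference() as the 'llm_inference_function'. A minimal sketch of calling plm_inference() directly for ESM-1v, mirroring the updated test in tests/test_api_functions.py further down; wt_seq and sequences are placeholder inputs, not repository data:

    import torch

    from pypef.utils.helpers import get_device
    from pypef.plm.esm_lora_tune import get_esm_models, tokenize_sequences
    from pypef.plm.inference import plm_inference

    wt_seq = 'MSKGEELFTG'                     # placeholder wild-type sequence
    sequences = ['MSKGEALFTG', 'MSKGEELFTG']  # placeholder variant sequences

    esm_base_model, _lora_model, esm_tokenizer, _optimizer = get_esm_models()
    esm_base_model = esm_base_model.to(get_device())
    x_esm, attention_mask = tokenize_sequences(
        sequences, esm_tokenizer, max_length=len(wt_seq) + 2)
    wt_tokens, _ = tokenize_sequences(
        [wt_seq], esm_tokenizer, max_length=len(wt_seq) + 2)
    wt_tokens = torch.tensor(wt_tokens[0], dtype=torch.long)  # 1-D wild-type token ids

    y_pred = plm_inference(xs=x_esm, wt_input_ids=wt_tokens,
                           attention_mask=attention_mask, model=esm_base_model)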

pypef/plm/prosst_lora_tune.py

Lines changed: 1 addition & 47 deletions
@@ -28,6 +28,7 @@
 from pypef.utils.helpers import get_device
 from pypef.plm.esm_lora_tune import tokenize_sequences
 from pypef.plm.utils import load_model_and_tokenizer
+from pypef.plm.inference import plm_inference


 def prosst_simple_vocab_aa_tokenizer(sequences, vocab, verbose=True):
@@ -272,50 +273,3 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq, verbose: bool = True):
     return input_ids, attention_mask, structure_input_ids


-def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose: bool = True):
-    if wt_seq is None:
-        raise SystemError(
-            "Running ProSST requires a wild-type sequence "
-            "FASTA file input for embedding sequences! "
-            "Specify a FASTA file with the --wt flag."
-        )
-    if pdb_file is None:
-        raise SystemError(
-            "Running ProSST requires a PDB file input "
-            "for embedding sequences! Specify a PDB file "
-            "with the --pdb flag."
-        )
-
-    pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq)
-    assert wt_seq == pdb_seq, (
-        f"Wild-type sequence is not matching PDB-extracted sequence:"
-        f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}"
-    )
-    prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
-    prosst_vocab = prosst_tokenizer.get_vocab()
-    prosst_base_model = prosst_base_model.to(device)
-    prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
-    input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
-        pdb_file, prosst_tokenizer, wt_seq, verbose=verbose
-    )
-    x_llm_train_prosst, _attention_mask = tokenize_sequences(
-        sequences=sequences, tokenizer=prosst_tokenizer,
-        max_length=len(wt_seq) + 2, verbose=verbose
-    )
-    llm_dict_prosst = {
-        'prosst': {
-            'llm_base_model': prosst_base_model,
-            'llm_model': prosst_lora_model,
-            'llm_optimizer': prosst_optimizer,
-            'llm_train_function': prosst_train,
-            'llm_inference_function': prosst_infer,
-            'llm_loss_function': corr_loss,
-            'x_llm' : x_llm_train_prosst,
-            'llm_attention_mask': prosst_attention_mask,
-            'llm_vocab': prosst_vocab,
-            'input_ids': input_ids,
-            'structure_input_ids': structure_input_ids,
-            'llm_tokenizer': prosst_tokenizer
-        }
-    }
-    return llm_dict_prosst
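
With prosst_setup() moved to pypef/plm/inference.py, the ProSST path is likewise scored through plm_inference() rather than prosst_infer(). A minimal sketch of that path, again mirroring the updated test further down; the PDB path and sequences are placeholders:

    from pypef.utils.helpers import get_device
    from pypef.plm.inference import plm_inference
    from pypef.plm.esm_lora_tune import tokenize_sequences
    from pypef.plm.prosst_lora_tune import get_prosst_models, get_structure_quantizied

    wt_seq = 'MSKGEELFTG'                     # placeholder wild-type sequence
    sequences = ['MSKGEALFTG', 'MSKGEELFTG']  # placeholder variant sequences
    pdb_file = 'wt_structure.pdb'             # placeholder PDB file path

    base_model, _lora_model, tokenizer, _optimizer = get_prosst_models()
    base_model = base_model.to(get_device())
    wt_input_ids, attention_mask, wt_structure_input_ids = get_structure_quantizied(
        pdb_file, tokenizer, wt_seq)
    x_prosst, _ = tokenize_sequences(
        sequences=sequences, tokenizer=tokenizer, max_length=len(wt_seq) + 2)

    y_pred = plm_inference(xs=x_prosst, wt_input_ids=wt_input_ids,
                           attention_mask=attention_mask, model=base_model,
                           wt_structure_input_ids=wt_structure_input_ids)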

scripts/ProteinGym_runs/official/benchmark_runs/pgym_cv_benchmark.py

Lines changed: 2 additions & 1 deletion
@@ -18,9 +18,10 @@

 from pypef.utils.variant_data import get_mismatches
 from pypef.plm.prosst_lora_tune import prosst_setup, prosst_simple_vocab_aa_tokenizer
-from pypef.plm.esm_lora_tune import esm_setup, tokenize_sequences
+from pypef.plm.esm_lora_tune import tokenize_sequences
 from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
 from pypef.hybrid.hybrid_model import DCALLMHybridModel
+from pypef.plm.inference import esm_setup


 @hydra.main(version_base=None, config_path="../configs", config_name="proteingym_data_setup")

tests/test_api_functions.py

Lines changed: 41 additions & 10 deletions
@@ -17,9 +17,7 @@
 from pypef.dca.gremlin_inference import GREMLIN
 from pypef.utils.variant_data import get_sequences_from_file, get_wt_sequence
 from pypef.plm.inference import plm_inference
-from pypef.plm.esm_lora_tune import esm_setup
-from pypef.plm.prosst_lora_tune import prosst_setup
-from pypef.plm.inference import inference, llm_tokenizer
+from pypef.plm.inference import esm_setup, prosst_setup, llm_tokenizer
 from pypef.hybrid.hybrid_model import DCALLMHybridModel
 from pypef.plm.esm_lora_tune import (
     get_esm_models, tokenize_sequences,
@@ -115,16 +113,49 @@ def test_hybrid_model_dca_llm():
     )
     assert len(train_seqs_aneh[0]) == len(g.wt_seq)
     aneh_wt_seq = get_wt_sequence(wt_seq_file_aneh)
-    y_pred_esm = inference(train_seqs_aneh, 'esm', wt_seq=aneh_wt_seq)
+    #y_pred_esm = inference(train_seqs_aneh, 'esm', wt_seq=aneh_wt_seq)
+
+    esm_base_model, _esm_lora_model, esm_tokenizer, _esm_optimizer = get_esm_models(
+        model='facebook/esm1v_t33_650M_UR90S_3')
+    esm_base_model = esm_base_model.to(get_device())
+    x_esm, esm_attention_mask = tokenize_sequences(
+        train_seqs_aneh, esm_tokenizer, max_length=len(wt_seq_file_aneh) + 2)
+    # Tokenize WT sequence once
+    wt_tokens, _ = tokenize_sequences(
+        [aneh_wt_seq],
+        esm_tokenizer,
+        max_length=len(aneh_wt_seq) + 2
+    )
+    wt_tokens = torch.tensor(wt_tokens[0], dtype=torch.long)  # shape (L,)
+    print(wt_tokens.shape)
+    print(esm_attention_mask.shape)
+    print(x_esm.shape)
+
+    y_pred_esm = plm_inference(xs=x_esm, wt_input_ids=wt_tokens,
+                               attention_mask=esm_attention_mask, model=esm_base_model)
     np.testing.assert_almost_equal(
         spearmanr(train_ys_aneh, y_pred_esm)[0],
         -0.713214007088901,
         decimal=7
     )
-    y_pred_prosst = inference(
-        train_seqs_aneh, 'prosst',
-        pdb_file=pdb_file_aneh, wt_seq=aneh_wt_seq
+
+    #y_pred_prosst = inference(
+    #    train_seqs_aneh, 'prosst',
+    #    pdb_file=pdb_file_aneh, wt_seq=aneh_wt_seq
+    #)
+    prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
+    prosst_vocab = prosst_tokenizer.get_vocab()
+    prosst_base_model = prosst_base_model.to(get_device())
+    wt_input_ids, prosst_attention_mask, wt_structure_input_ids = get_structure_quantizied(
+        pdb_blat_ecolx, prosst_tokenizer, aneh_wt_seq)
+    x_prosst, prosst_attention_mask_ = tokenize_sequences(
+        sequences=train_seqs_aneh,
+        tokenizer=prosst_tokenizer,
+        max_length=len(wt_seq_file_aneh) + 2
     )
+    y_pred_prosst = plm_inference(xs=x_prosst, wt_input_ids=wt_input_ids,
+                                  attention_mask=prosst_attention_mask, model=prosst_base_model,
+                                  wt_structure_input_ids=wt_structure_input_ids)
     np.testing.assert_almost_equal(
         spearmanr(train_ys_aneh, y_pred_prosst)[0],
         -0.7394433335146882,
@@ -383,8 +414,8 @@ def test_plm_corr_blat_ecolx():


 if __name__ == "__main__":
-    #test_gremlin_avgfp()
-    #test_hybrid_model_dca_llm()
-    #test_dataset_b_results()
+    test_gremlin_avgfp()
+    test_hybrid_model_dca_llm()
+    test_dataset_b_results()
     test_plm_corr_blat_ecolx()
