
Commit 663a163

Add test for new plm_inference() function

1 parent: ce54e9b
File tree: 8 files changed (+65, -65 lines)

README.md

Lines changed: 2 additions & 2 deletions

@@ -48,9 +48,9 @@ When incorporating DCA and PLM features, both models are fine-tuned via few-shot
 A quick installation of the PyPEF command line framework using PyPI for Linux and Windows and Python >= 3.10 can be performed with:
 
 ```bash
-pip install -U pypef
-# optionally, for GPU support (see requirements section below):
+# For GPU support (e.g., using CUDA 12.8, see requirements section below):
 # pip install torch --index-url https://download.pytorch.org/whl/cu128
+pip install -U pypef
 ```
 
 After successful installation, PyPEF should work by calling `pypef` in the shell:
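
The reordered install block puts the optional CUDA-enabled torch wheel before `pypef`. A minimal sanity check (assuming `torch` is installed; not part of this commit) to verify which build is active:

```python
# Minimal check (not part of this commit): confirm whether the optional
# CUDA-enabled torch build was picked up before installing pypef.
import torch

print("torch version:", torch.__version__)           # e.g. ends in +cu128 for the cu128 wheel
print("CUDA available:", torch.cuda.is_available())  # False means the CPU-only wheel is active
if torch.cuda.is_available():
    print("CUDA runtime:", torch.version.cuda)
    print("device:", torch.cuda.get_device_name(0))
```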

pypef/plm/esm_lora_tune.py

Lines changed: 29 additions & 29 deletions

@@ -143,13 +143,14 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=Fals
     return torch.flatten(y_preds_total)
 
 
-def esm_unmasked_wt_score(
+def unmasked_wt_score(
         tokenized_sequences,
         attention_mask,
         wt_input_ids,
         model,
         train: bool = False,
-        device=None,
+        cut_special_tokens: bool = True,  # assumption: cut first and last token
+        device=None,
         **kwargs
 ):
     if device is None:

@@ -189,27 +190,29 @@ def esm_unmasked_wt_score(
 
     logits = outputs.logits
     logits = logits.squeeze(0)  # remove batch dim
-    #print('logits.shape:', logits.shape)
     # Better make sure that special tokens are always removed / masked
     # and only pure amino acid sequence tokens are present / unmasked
-    #logits = logits[1:-1]  # drop CLS/EOS
+    tokenized_seq_len = tokenized_sequences.shape[1]
+    if cut_special_tokens:
+        logits = logits[1:-1]  # drop CLS/EOS
+        tokenized_seq_len -= 2
     token_probs = torch.log_softmax(logits, dim=-1)
-    assert len(tokenized_sequences[0]) == token_probs.shape[0], f"{len(tokenized_sequences[0])} != {token_probs.shape[0]}"
-    #print('token_probs.shape:', token_probs.shape)
-
-    for i_s, tokenized_seq in enumerate(tokenized_sequences):
-        for i_aa, aa in enumerate(tokenized_seq):
-            # alternative: use Tensor.index_select() function
-            if i_aa == 0:
-                seq_log_probs = token_probs[i_aa, aa].reshape(1)
-            else:
-                seq_log_probs = torch.cat(
-                    (seq_log_probs, token_probs[i_aa, aa].reshape(1)), 0)
-        if i_s == 0:
-            log_probs = torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
-        else:
-            log_probs = torch.cat(
-                (log_probs, torch.sum(torch.Tensor(seq_log_probs)).reshape(1)), 0)
+    assert tokenized_seq_len == token_probs.shape[0], (
+        f"{tokenized_seq_len} != {token_probs.shape[0]}")
+
+    log_probs = []
+    for tokenized_seq in tokenized_sequences:
+        if cut_special_tokens:
+            tokenized_seq = tokenized_seq[1:-1]
+
+        seq_lp = token_probs[
+            torch.arange(tokenized_seq.shape[0], device=tokenized_seq.device),
+            tokenized_seq
+        ].sum(dtype=torch.float64)
+
+        log_probs.append(seq_lp)
+
+    log_probs = torch.stack(log_probs)
     return log_probs
 
 

@@ -285,7 +288,7 @@ def esm_mutation_all_pos_masked_pll(
         verbose: bool = False,
 ):
     """
-    Correct mutation-only pseudo-log-likelihood for ONE sequence.
+    Correct mutation-only pseudo-log-likelihood for sequences.
     """
     model.eval()
 

@@ -332,13 +335,14 @@ def esm_mutation_all_pos_masked_pll(
     return plls
 
 
-def esm_infer_pll(
+def plm_inference(
         xs,
         wt_input_ids,
         attention_mask,
         model,
         mask_token_id,
         inference_type='unmasked',
+        wt_structure_input_ids=None,
         batch_size=5,
         train=False,
         device=None,

@@ -354,23 +358,19 @@ def esm_infer_pll(
 
     if not isinstance(attention_mask, torch.Tensor):
         attention_mask = torch.tensor(attention_mask, dtype=torch.long)
-    wt_structure_input_ids = None
     if inference_type == 'mutation-masking':
         inference_function = esm_mutation_only_mutation_masked_pll
     elif inference_type in ['full-masking', 'all-pos-masking']:
         inference_function = esm_mutation_all_pos_masked_pll
     elif inference_type in ['unmasked', 'wt-marginals']:
-        inference_function = esm_unmasked_wt_score
-    elif inference_type == 'prosst':
-        wt_input_ids, wt_structure_input_ids = wt_input_ids
-        inference_function = esm_unmasked_wt_score
+        inference_function = unmasked_wt_score
     else:
-        raise SystemError("Choose between 'mutation_masking', 'unmasked', and 'full_masking'")
+        raise SystemError("Choose between 'mutation-masking', 'unmasked', and 'full-masking'")
 
     scores = []
 
     xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
-    desc = f"ESM inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
+    desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
 
     pbar = tqdm(
         range(len(xs_b)),
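
The rewritten `unmasked_wt_score()` replaces the per-token concatenation loop with a single advanced-indexing gather over (position, token id) pairs, summed in float64. A standalone sketch with toy tensors (random logits, arbitrary vocabulary size; not repo code) showing the gather matches the old accumulation:

```python
# Toy comparison (not repo code): vectorized gather of per-position
# log-probabilities versus the old one-token-at-a-time accumulation.
import torch

torch.manual_seed(0)
seq_len, vocab_size = 6, 20
logits = torch.randn(seq_len, vocab_size)
token_probs = torch.log_softmax(logits, dim=-1)
tokenized_seq = torch.randint(0, vocab_size, (seq_len,))

# New style: one gather over (position, token id) pairs, summed in float64
vectorized = token_probs[
    torch.arange(seq_len), tokenized_seq
].sum(dtype=torch.float64)

# Old style: accumulate one position at a time
looped = torch.zeros((), dtype=torch.float64)
for i_aa, aa in enumerate(tokenized_seq):
    looped += token_probs[i_aa, aa].to(torch.float64)

assert torch.allclose(vectorized, looped)
print(float(vectorized))  # summed log-likelihood of the observed tokens
```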

pypef/plm/inference.py

Lines changed: 3 additions & 3 deletions

@@ -9,7 +9,7 @@
 from pypef.utils.helpers import get_device
 from pypef.plm.utils import get_batches
 from pypef.plm.esm_lora_tune import esm_infer, esm_setup, tokenize_sequences
-from pypef.plm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences, prosst_infer
+from pypef.plm.prosst_lora_tune import prosst_setup, prosst_simple_vocab_aa_tokenizer, prosst_infer
 
 import logging
 logger = logging.getLogger('pypef.llm.inference')

@@ -26,7 +26,7 @@ def llm_tokenizer(llm_dict, seqs, verbose=True):
             max_length=len(seqs[0]), verbose=verbose
         )
     elif list(llm_dict.keys())[0] == 'prosst':
-        x_llm_seqs = prosst_tokenize_sequences(
+        x_llm_seqs = prosst_simple_vocab_aa_tokenizer(
             seqs, vocab=llm_dict['prosst']['llm_vocab'], verbose=verbose
         )
     else:

@@ -80,4 +80,4 @@ def inference(
         ).cpu()
     else:
         raise RuntimeError("Unknown LLM option.")
-    return y_test_pred
+    return y_test_pred
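
`llm_tokenizer()` keeps dispatching on the single top-level key of `llm_dict`; only the ProSST branch now calls the renamed `prosst_simple_vocab_aa_tokenizer`. A toy sketch of that key-based dispatch (dictionary contents are made up):

```python
# Toy illustration (made-up dict contents): the tokenizer path is chosen
# from the single top-level key of llm_dict ('prosst' vs. an ESM-style key).
def pick_tokenizer_kind(llm_dict: dict) -> str:
    kind = list(llm_dict.keys())[0]
    if kind == 'prosst':
        return 'prosst_simple_vocab_aa_tokenizer'
    return 'tokenize_sequences'  # ESM-style Hugging Face tokenizer path

print(pick_tokenizer_kind({'prosst': {'llm_vocab': {}}}))  # prosst_simple_vocab_aa_tokenizer
print(pick_tokenizer_kind({'esm1v': {}}))                  # tokenize_sequences
```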

pypef/plm/prosst_lora_tune.py

Lines changed: 5 additions & 4 deletions

@@ -29,21 +29,22 @@
 from pypef.plm.utils import load_model_and_tokenizer
 
 
-def prosst_tokenize_sequences(sequences, vocab, verbose=True):
+def prosst_simple_vocab_aa_tokenizer(sequences, vocab, verbose=True):
     print(vocab)
     sequences = np.atleast_1d(sequences).tolist()
     x_sequences = []
     for sequence in tqdm(
         sequences, desc='Tokenizing sequences for ProSST modeling',
         disable=not verbose
     ):
-        x_sequence = [vocab['<cls>']]
+        #x_sequence = [vocab['<cls>']]
+        x_sequence = []
         for aa in sequence:
             try:
                 x_sequence.append(vocab[aa])
             except KeyError:
                 x_sequence.append(vocab['<unk>'])
-        x_sequence.append(vocab['<eos>'])
+        #x_sequence.append(vocab['<eos>'])
         x_sequences.append(x_sequence)
     return torch.Tensor(x_sequences).to(torch.int)
 

@@ -296,7 +297,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
     input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
         pdb_file, prosst_tokenizer, wt_seq, verbose=verbose
     )
-    x_llm_train_prosst = prosst_tokenize_sequences(
+    x_llm_train_prosst = prosst_simple_vocab_aa_tokenizer(
         sequences=sequences, vocab=prosst_vocab, verbose=verbose
     )
     llm_dict_prosst = {
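
With the `<cls>`/`<eos>` appends commented out, the renamed `prosst_simple_vocab_aa_tokenizer` emits plain per-residue vocabulary ids with an `<unk>` fallback. A self-contained sketch of that behavior using a hypothetical mini-vocabulary (not the real ProSST vocab):

```python
# Self-contained sketch (hypothetical mini-vocabulary): map amino acids
# straight to vocab ids, fall back to '<unk>', and add no special tokens.
import torch

toy_vocab = {'<cls>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'C': 4, 'D': 5, 'E': 6}

def simple_vocab_aa_tokenize(sequences, vocab):
    x_sequences = []
    for sequence in sequences:
        x_sequence = []  # previously started with [vocab['<cls>']]
        for aa in sequence:
            x_sequence.append(vocab.get(aa, vocab['<unk>']))
        # previously appended vocab['<eos>'] here
        x_sequences.append(x_sequence)
    return torch.Tensor(x_sequences).to(torch.int)

print(simple_vocab_aa_tokenize(['ACDE', 'ACDX'], toy_vocab))
# tensor([[3, 4, 5, 6],
#         [3, 4, 5, 2]], dtype=torch.int32)
```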

scripts/ProteinGym_runs/official/benchmark_runs/pgym_cv_benchmark.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 
 
 from pypef.utils.variant_data import get_mismatches
-from pypef.plm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences
+from pypef.plm.prosst_lora_tune import prosst_setup, prosst_simple_vocab_aa_tokenizer
 from pypef.plm.esm_lora_tune import esm_setup, tokenize_sequences
 from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
 from pypef.hybrid.hybrid_model import DCALLMHybridModel

@@ -177,7 +177,7 @@ def main(cfg: DictConfig) -> None:
             device='cuda'
         )
         vocab = llm_kwargs['prosst']['llm_vocab']
-        x_llm_test = np.asarray(prosst_tokenize_sequences(
+        x_llm_test = np.asarray(prosst_simple_vocab_aa_tokenizer(
             sequences=s_test, vocab=vocab, verbose=False))
     elif llm == "esm1v":
         llm_kwargs = esm_setup(sequences=s_train)

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
 )
 from pypef.plm.prosst_lora_tune import (
     get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
-    prosst_tokenize_sequences, prosst_train
+    prosst_simple_vocab_aa_tokenizer, prosst_train
 )
 from pypef.plm.inference import inference
 from pypef.utils.variant_data import get_seqs_from_var_name

@@ -165,7 +165,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
         input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
             pdb, prosst_tokenizer, wt_seq, verbose=False
         )
-        x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab, verbose=False)
+        x_prosst = prosst_simple_vocab_aa_tokenizer(sequences=sequences, vocab=prosst_vocab, verbose=False)
         y_prosst = inference(sequences, 'prosst', pdb_file=pdb, wt_seq=wt_seq, model=prosst_base_model, verbose=False)
         print(f'ProSST (unsupervised performance): '
               f'{spearmanr(fitnesses, y_prosst.cpu())[0]:.3f}')

scripts/ProteinGym_runs/protgym_hybrid_perf_test_low_n.py

Lines changed: 2 additions & 2 deletions

@@ -29,7 +29,7 @@
 )
 from pypef.plm.prosst_lora_tune import (
     get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
-    prosst_tokenize_sequences, prosst_train
+    prosst_simple_vocab_aa_tokenizer, prosst_train
 )
 from pypef.utils.variant_data import get_seqs_from_var_name
 from pypef.utils.helpers import get_vram, get_device

@@ -159,7 +159,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
         try:
             input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
                 pdb, prosst_tokenizer, wt_seq)
-            x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
+            x_prosst = prosst_simple_vocab_aa_tokenizer(sequences=sequences, vocab=prosst_vocab)
             y_prosst = get_logits_from_full_seqs(
                 x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
                 structure_input_ids, train=False

tests/test_api_functions.py

Lines changed: 20 additions & 21 deletions

@@ -17,7 +17,7 @@
 from pypef.ml.regression import AAIndexEncoding, full_aaidx_txt_path, get_regressor_performances
 from pypef.dca.gremlin_inference import GREMLIN
 from pypef.utils.variant_data import get_sequences_from_file, get_wt_sequence
-from pypef.plm.esm_lora_tune import esm_infer, esm_infer_pll, esm_setup, esm_train
+from pypef.plm.esm_lora_tune import esm_infer, plm_inference, esm_setup, esm_train
 from pypef.plm.prosst_lora_tune import prosst_setup
 from pypef.plm.inference import inference, llm_tokenizer
 from pypef.hybrid.hybrid_model import DCALLMHybridModel

@@ -26,7 +26,7 @@
 )
 from pypef.plm.prosst_lora_tune import (
     get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
-    prosst_tokenize_sequences
+    prosst_simple_vocab_aa_tokenizer
 )
 from pypef.utils.helpers import get_device
 

@@ -258,10 +258,6 @@ def test_plm_corr_blat_ecolx():
     prosst_base_model = prosst_base_model.to(device)
     df = pd.read_csv(csv_blat_ecolx_stiffler2015)
     sequences = df['mutated_sequence'].to_list()
-    print(sequences[0][23])
-    print(sequences[1][23])
-    print('len(sequences[0]):', len(sequences[0]))
-    print('len(blat_ecolx_wt_seq):', len(blat_ecolx_wt_seq))
     y_true = df['DMS_score'].to_list()
     for x in ['facebook/esm1v_t33_650M_UR90S_3']:
         esm_base_model, _esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models(model=x)

@@ -275,7 +271,7 @@ def test_plm_corr_blat_ecolx():
             max_length=len(blat_ecolx_wt_seq) + 2
         )
         wt_tokens = torch.tensor(wt_tokens[0], dtype=torch.long)  # shape (L,)
-        y_esm = esm_infer_pll(
+        y_esm = plm_inference(
             xs=x_esm,
             wt_input_ids=wt_tokens,
             attention_mask=esm_attention_mask,

@@ -289,7 +285,7 @@ def test_plm_corr_blat_ecolx():
         print(f'{x}: ESM1v (unsupervised performance): '
               f'{spearmanr(y_true, y_esm.cpu())[0]}')
         np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6367826285982324, decimal=6)
-        y_esm = esm_infer_pll(
+        y_esm = plm_inference(
             xs=x_esm,
             wt_input_ids=wt_tokens,
             attention_mask=esm_attention_mask,

@@ -303,7 +299,7 @@ def test_plm_corr_blat_ecolx():
         print(f'{x}: ESM1v (unsupervised performance): '
               f'{spearmanr(y_true, y_esm.cpu())[0]}')
         np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6498987261125897, decimal=6)
-        #y_esm = esm_infer_pll(
+        #y_esm = plm_inference(
         #    xs=x_esm,
         #    wt_input_ids=wt_tokens,
         #    attention_mask=esm_attention_mask,

@@ -317,31 +313,34 @@ def test_plm_corr_blat_ecolx():
         #print(f'{x}: ESM1v (unsupervised performance): '
         #      f'{spearmanr(y_true, y_esm.cpu())[0]}')
         #np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.666666666666666, decimal=6)
-
     wt_input_ids, prosst_attention_mask, wt_structure_input_ids = get_structure_quantizied(
         pdb_blat_ecolx, prosst_tokenizer, blat_ecolx_wt_seq)
-    x_prosst = tokenize_sequences(sequences=sequences, tokenizer=prosst_tokenizer)
-    y_prosst = get_logits_from_full_seqs(
-        x_prosst, prosst_base_model, wt_input_ids, prosst_attention_mask,
-        wt_structure_input_ids, train=False, verbose=True
+    x_prosst2 = prosst_simple_vocab_aa_tokenizer(sequences, prosst_vocab)
+    x_prosst, prosst_attention_mask_ = tokenize_sequences(
+        sequences=sequences,
+        tokenizer=prosst_tokenizer,
+        max_length=len(blat_ecolx_wt_seq) + 2
     )
-    print(f'ProSST (unsupervised performance): '  # ProteinGym: ProSST: 0.760
-          f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
+    assert x_prosst[0][1:-1] == x_prosst2.tolist()[0], (
+        f"{x_prosst[0][1:-1]} != {x_prosst2.tolist()[0]}")
+    assert prosst_attention_mask.tolist()[0] == prosst_attention_mask_, (
+        f"{prosst_attention_mask.tolist()[0]} != {prosst_attention_mask_}")
 
-    y_prosst = esm_infer_pll(
+    y_prosst = plm_inference(
         xs=x_prosst,
-        wt_input_ids=(wt_input_ids, wt_structure_input_ids),  ## TODO
+        wt_input_ids=wt_input_ids,
         attention_mask=prosst_attention_mask,
         model=prosst_base_model,
         mask_token_id=prosst_tokenizer.mask_token_id,
-        inference_type='prosst',  ## TODO
+        inference_type='unmasked',
+        wt_structure_input_ids=wt_structure_input_ids,
         batch_size=5,
         train=False,
         verbose=True
     )
     print(f'ProSST (unsupervised performance): '  # ProteinGym: ProSST: 0.760
-          f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
-    # ACTUAL OLD VERSION: 0.743
+          f'{spearmanr(y_true, y_prosst.cpu())[0]}')
+    np.testing.assert_almost_equal(spearmanr(y_true, y_prosst.cpu())[0], 0.7430279087189432, decimal=6)
 
 
 
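
The new test cross-checks the Hugging Face-style tokenization against the simple vocabulary tokenizer by stripping the first and last (special) tokens before comparing. A toy version of that check with hypothetical token ids (not the real ProSST vocabulary):

```python
# Toy version of the new consistency check (hypothetical token ids): the full
# tokenizer output wrapped in CLS/EOS should match the plain vocab encoding
# once the first and last tokens are stripped, mirroring x_prosst[0][1:-1].
full_tokenizer_ids = [0, 3, 4, 5, 6, 1]   # <cls> A C D E <eos>
simple_vocab_ids = [3, 4, 5, 6]           # A C D E, no special tokens

assert full_tokenizer_ids[1:-1] == simple_vocab_ids, (
    f"{full_tokenizer_ids[1:-1]} != {simple_vocab_ids}")
print("inner tokens match:", full_tokenizer_ids[1:-1])
```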
