
Commit ce54e9b

dev: todo: More uniform tokenization and plm inference
1 parent 3ef22b3 commit ce54e9b

File tree: 8 files changed, +70 −35 lines


.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ jobs:
           flake8 ./pypef --count --select=E9,F63,F7,F82 --show-source --statistics
       - name: Export Pythonpath and run tests using the main script
         run: |
-          export PYTHONPATH="${PYTHONPATH}:${PWD}" && python -m pytest ./tests/ -v -m "not pip_specific"
+          export PYTHONPATH="${PYTHONPATH}:${PWD}" && python -m pytest ./tests/ -v -m "not (pip_specific or requires_gpu)"
       - name: Export Pythonpath and run tests using pip-installation
         run: |
           export PYTHONPATH=""
@@ -81,7 +81,7 @@ jobs:
       - name: Export Pythonpath and run tests using the main script
         shell: pwsh
         run: |
-          $env:PYTHONPATH = "${PWD};${env:PYTHONPATH}";python -m pytest .\tests\ -v -m "not pip_specific"
+          $env:PYTHONPATH = "${PWD};${env:PYTHONPATH}";python -m pytest .\tests\ -v -m "not (pip_specific or requires_gpu)"
       - name: Export Pythonpath and run tests using pip-installation
         shell: pwsh
         run: |
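Note: pytest deselects on a marker expression even for unregistered markers, but it warns about unknown marker names (and errors under --strict-markers). A minimal sketch of registering the two markers used in the filter above in a conftest.py; this commit does not show where the markers are actually declared, so placement and wording here are assumptions:

# conftest.py -- hypothetical marker registration; marker names taken from the CI filter
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "pip_specific: tests that only run against the pip installation"
    )
    config.addinivalue_line(
        "markers", "requires_gpu: tests that need a CUDA-capable GPU, skipped in CI"
    )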

pypef/plm/esm_lora_tune.py

Lines changed: 39 additions & 8 deletions

@@ -48,7 +48,7 @@ def get_esm_models(model='facebook/esm1v_t33_650M_UR90S_3'):
     return base_model, lora_model, tokenizer, optimizer
 
 
-def esm_tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
+def tokenize_sequences(sequences, tokenizer, max_length, verbose=True):
     tokenized_sequences = []
     for seq in tqdm(sequences, desc='Tokenizing sequences for ESM modeling', disable=not verbose):
         encoded_sequence, attention_mask = tokenizer(
@@ -154,18 +154,49 @@ def esm_unmasked_wt_score(
 ):
     if device is None:
         device = get_device()
-    wt_input_ids = wt_input_ids.unsqueeze(0)
+    if wt_input_ids.dim() == 1:
+        wt_input_ids = wt_input_ids.unsqueeze(0)
+    structure_input_ids = kwargs.get("structure_input_ids", None)
     attention_masks = torch.Tensor(np.full(
         shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64)
     if train:
-        outputs = model(wt_input_ids.to(device), attention_masks.to(device),
-                        output_hidden_states=False)
+        if structure_input_ids is not None:
+            outputs = model(
+                input_ids=wt_input_ids.to(device),
+                attention_mask=attention_masks.to(device),
+                ss_input_ids=structure_input_ids.to(device)
+            )
+        else:
+            outputs = model(
+                wt_input_ids.to(device),
+                attention_masks.to(device),
+                output_hidden_states=False
+            )
     else:
         with torch.no_grad():
-            outputs = model(wt_input_ids.to(device), attention_masks.to(device),
-                            output_hidden_states=False)
+            if structure_input_ids is not None:
+                outputs = model(
+                    input_ids=wt_input_ids.to(device),
+                    attention_mask=attention_masks.to(device),
+                    ss_input_ids=structure_input_ids.to(device)
+                )
+            else:
+                outputs = model(
+                    wt_input_ids.to(device),
+                    attention_masks.to(device),
+                    output_hidden_states=False
+                )
+
     logits = outputs.logits
-    token_probs = torch.log_softmax(logits, dim=-1).squeeze(0)
+    logits = logits.squeeze(0)  # remove batch dim
+    #print('logits.shape:', logits.shape)
+    # Better make sure that special tokens are always removed / masked
+    # and only pure amino acid sequence tokens are present / unmasked
+    #logits = logits[1:-1]  # drop CLS/EOS
+    token_probs = torch.log_softmax(logits, dim=-1)
+    assert len(tokenized_sequences[0]) == token_probs.shape[0], f"{len(tokenized_sequences[0])} != {token_probs.shape[0]}"
+    #print('token_probs.shape:', token_probs.shape)
+
     for i_s, tokenized_seq in enumerate(tokenized_sequences):
         for i_aa, aa in enumerate(tokenized_seq):
             # alternative: use Tensor.index_select() function
@@ -417,7 +448,7 @@ def esm_train(
 def esm_setup(sequences, device: str | None = None, verbose: bool = True):
     esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
     esm_base_model = esm_base_model.to(device)
-    x_esm, esm_attention_mask = esm_tokenize_sequences(
+    x_esm, esm_attention_mask = tokenize_sequences(
         sequences, esm_tokenizer, max_length=len(sequences[0]), verbose=verbose)
     llm_dict_esm = {
         'esm1v': {
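Note: the reworked esm_unmasked_wt_score keeps the full token axis and asserts that each tokenized sequence aligns one-to-one with token_probs before the scoring loop runs. A minimal sketch of what that loop accumulates per variant (hypothetical helper name; assumes token_probs is the (seq_len, vocab_size) log-softmax from the single wild-type forward pass, as above):

import torch

def sum_token_log_probs(token_probs: torch.Tensor, tokenized_seq) -> float:
    # token_probs: (seq_len, vocab_size), i.e. torch.log_softmax(logits, dim=-1)
    # tokenized_seq: one variant's token ids, aligned position-by-position with
    # the wild-type pass (the new assert guards exactly this alignment)
    total = 0.0
    for i_aa, aa in enumerate(tokenized_seq):
        total += token_probs[i_aa, int(aa)].item()
    return total  # higher sum: variant tokens more probable under the WT pass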

pypef/plm/inference.py

Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 
 from pypef.utils.helpers import get_device
 from pypef.plm.utils import get_batches
-from pypef.plm.esm_lora_tune import esm_infer, esm_setup, esm_tokenize_sequences
+from pypef.plm.esm_lora_tune import esm_infer, esm_setup, tokenize_sequences
 from pypef.plm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences, prosst_infer
 
 import logging
@@ -21,7 +21,7 @@ def llm_tokenizer(llm_dict, seqs, verbose=True):
     except ValueError:
         raise SystemError("Unequal input sequence length detected!")
     if list(llm_dict.keys())[0] == 'esm1v':
-        x_llm_seqs, _attention_mask = esm_tokenize_sequences(
+        x_llm_seqs, _attention_mask = tokenize_sequences(
             seqs, tokenizer=llm_dict['esm1v']['llm_tokenizer'],
             max_length=len(seqs[0]), verbose=verbose
         )
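Note: llm_tokenizer dispatches on the single top-level key of llm_dict, so the dict handed in (as built by esm_setup or prosst_setup) needs that key to name the model family. A shape sketch under that assumption; the real dict carries further model and optimizer entries:

from pypef.plm.esm_lora_tune import get_esm_models
from pypef.plm.inference import llm_tokenizer

_, _, esm_tokenizer, _ = get_esm_models()
# Minimal dict shape consumed by llm_tokenizer; esm_setup() builds the full dict
llm_dict = {'esm1v': {'llm_tokenizer': esm_tokenizer}}
tokens = llm_tokenizer(llm_dict, seqs=['MKTAYIA', 'MRTAYIA'])  # equal-length sequences
# required; assumed to hand back the tokenized inputs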

pypef/plm/prosst_lora_tune.py

Lines changed: 7 additions & 2 deletions

@@ -30,15 +30,20 @@
 
 
 def prosst_tokenize_sequences(sequences, vocab, verbose=True):
+    print(vocab)
     sequences = np.atleast_1d(sequences).tolist()
     x_sequences = []
     for sequence in tqdm(
         sequences, desc='Tokenizing sequences for ProSST modeling',
         disable=not verbose
     ):
-        x_sequence = []
+        x_sequence = [vocab['<cls>']]
         for aa in sequence:
-            x_sequence.append(vocab[aa])
+            try:
+                x_sequence.append(vocab[aa])
+            except KeyError:
+                x_sequence.append(vocab['<unk>'])
+        x_sequence.append(vocab['<eos>'])
         x_sequences.append(x_sequence)
     return torch.Tensor(x_sequences).to(torch.int)
 
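Note: with this change every ProSST input is bracketed by <cls>/<eos> ids, and residues missing from the vocabulary fall back to <unk> instead of raising a KeyError. A toy illustration; the vocabulary ids below are invented for the example:

from pypef.plm.prosst_lora_tune import prosst_tokenize_sequences

# Made-up vocab ids, for illustration only
vocab = {'<cls>': 0, '<eos>': 2, '<unk>': 3, 'A': 5, 'C': 6}
x = prosst_tokenize_sequences(['ACX'], vocab, verbose=False)
# 'X' is not in the vocab, so it maps to '<unk>':
# tensor([[0, 5, 6, 3, 2]], dtype=torch.int32)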

scripts/ProteinGym_runs/official/benchmark_runs/pgym_cv_benchmark.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
 
 from pypef.utils.variant_data import get_mismatches
 from pypef.plm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences
-from pypef.plm.esm_lora_tune import esm_setup, esm_tokenize_sequences
+from pypef.plm.esm_lora_tune import esm_setup, tokenize_sequences
 from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
 from pypef.hybrid.hybrid_model import DCALLMHybridModel
 
@@ -182,7 +182,7 @@ def main(cfg: DictConfig) -> None:
     elif llm == "esm1v":
         llm_kwargs = esm_setup(sequences=s_train)
         tokenizer = llm_kwargs['esm1v']['llm_tokenizer']
-        x_llm_test, _attn_masks = esm_tokenize_sequences(
+        x_llm_test, _attn_masks = tokenize_sequences(
            sequences=s_test, tokenizer=tokenizer, max_length=len(s_test[0])
        )

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 2 additions & 2 deletions

@@ -22,7 +22,7 @@
 from pypef.dca.gremlin_inference import GREMLIN
 from pypef.plm.utils import get_batches, corr_loss
 from pypef.plm.esm_lora_tune import (
-    get_esm_models, esm_tokenize_sequences,
+    get_esm_models, tokenize_sequences,
     esm_train, esm_infer
 )
 from pypef.plm.prosst_lora_tune import (
@@ -151,7 +151,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
     dca_unopt_perf = spearmanr(fitnesses, y_pred_dca)[0]
     # ESM unsupervised
     try:
-        x_esm, esm_attention_mask = esm_tokenize_sequences(
+        x_esm, esm_attention_mask = tokenize_sequences(
             sequences, esm_tokenizer, max_length=len(wt_seq), verbose=False
         )
         y_esm = inference(sequences, 'esm', model=esm_base_model, verbose=False)

scripts/ProteinGym_runs/protgym_hybrid_perf_test_low_n.py

Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@
 from pypef.dca.gremlin_inference import GREMLIN
 from pypef.plm.utils import get_batches
 from pypef.plm.esm_lora_tune import (
-    get_esm_models, esm_tokenize_sequences,
+    get_esm_models, tokenize_sequences,
     esm_train, esm_infer, corr_loss
 )
 from pypef.plm.prosst_lora_tune import (
@@ -143,7 +143,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
     dca_unopt_perf = spearmanr(fitnesses, y_pred_dca)[0]
 
     try:
-        x_esm, esm_attention_mask = esm_tokenize_sequences(
+        x_esm, esm_attention_mask = tokenize_sequences(
             sequences, esm_tokenizer, max_length=len(wt_seq))
         y_esm = esm_infer(
             get_batches(x_esm, dtype=float, batch_size=1),

tests/test_api_functions.py

Lines changed: 14 additions & 15 deletions

@@ -22,7 +22,7 @@
 from pypef.plm.inference import inference, llm_tokenizer
 from pypef.hybrid.hybrid_model import DCALLMHybridModel
 from pypef.plm.esm_lora_tune import (
-    get_esm_models, esm_tokenize_sequences,
+    get_esm_models, tokenize_sequences,
 )
 from pypef.plm.prosst_lora_tune import (
     get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
@@ -266,10 +266,10 @@ def test_plm_corr_blat_ecolx():
     for x in ['facebook/esm1v_t33_650M_UR90S_3']:
         esm_base_model, _esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models(model=x)
         esm_base_model = esm_base_model.to(device)
-        x_esm, esm_attention_mask = esm_tokenize_sequences(
+        x_esm, esm_attention_mask = tokenize_sequences(
             sequences, esm_tokenizer, max_length=len(blat_ecolx_wt_seq) + 2)
         # Tokenize WT sequence once
-        wt_tokens, _ = esm_tokenize_sequences(
+        wt_tokens, _ = tokenize_sequences(
             [blat_ecolx_wt_seq],
             esm_tokenizer,
             max_length=len(blat_ecolx_wt_seq) + 2
@@ -316,21 +316,18 @@ def test_plm_corr_blat_ecolx():
         #)
         #print(f'{x}: ESM1v (unsupervised performance): '
         #      f'{spearmanr(y_true, y_esm.cpu())[0]}')
-        #np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6360209552304472, decimal=6)
+        #np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.666666666666666, decimal=6)
 
         wt_input_ids, prosst_attention_mask, wt_structure_input_ids = get_structure_quantizied(
             pdb_blat_ecolx, prosst_tokenizer, blat_ecolx_wt_seq)
-        x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
-        #y_prosst = get_logits_from_full_seqs(
-        #    x_prosst, prosst_base_model, wt_input_ids, prosst_attention_mask,
-        #    wt_structure_input_ids, train=False, verbose=True
-        #)
-        #print(f'ProSST (unsupervised performance): '  # ProteinGym: ProSST: 0.760
-        #      f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
-        print('wt_input_ids:', wt_input_ids)
-        print()
-        print('wt_structure_input_ids:', wt_structure_input_ids)
-        print()
+        x_prosst = tokenize_sequences(sequences=sequences, tokenizer=prosst_tokenizer)
+        y_prosst = get_logits_from_full_seqs(
+            x_prosst, prosst_base_model, wt_input_ids, prosst_attention_mask,
+            wt_structure_input_ids, train=False, verbose=True
+        )
+        print(f'ProSST (unsupervised performance): '  # ProteinGym: ProSST: 0.760
+              f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
+
         y_prosst = esm_infer_pll(
             xs=x_prosst,
             wt_input_ids=(wt_input_ids, wt_structure_input_ids),  ## TODO
@@ -342,6 +339,8 @@ def test_plm_corr_blat_ecolx():
             train=False,
             verbose=True
         )
+        print(f'ProSST (unsupervised performance): '  # ProteinGym: ProSST: 0.760
+              f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
         # ACTUAL OLD VERSION: 0.743
 
 