Skip to content

Commit 77d2bbf

Browse files
committed
Update tokenization so the token sequence max_length is len(wt_seq) + 2 (accounting for the added cls/eos tokens)
1 parent 7c3c2f9 commit 77d2bbf

File tree

5 files changed

+71
-36
lines changed

5 files changed

+71
-36
lines changed

pypef/plm/esm_lora_tune.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,16 @@ def esm_train(
191191
model.train(False)
192192

193193

194-
def esm_setup(sequences, device: str | None = None, verbose: bool = True):
194+
def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True):
195195
esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
196196
esm_base_model = esm_base_model.to(device)
197+
wt_tokens, _ = tokenize_sequences(
198+
[wt_seq],
199+
esm_tokenizer,
200+
max_length=len(wt_seq) + 2
201+
)
197202
x_esm, esm_attention_mask = tokenize_sequences(
198-
sequences, esm_tokenizer, max_length=len(sequences[0]), verbose=verbose)
203+
sequences, esm_tokenizer, max_length=len(wt_seq) + 2, verbose=verbose)
199204
llm_dict_esm = {
200205
'esm1v': {
201206
'llm_base_model': esm_base_model,
@@ -205,6 +210,7 @@ def esm_setup(sequences, device: str | None = None, verbose: bool = True):
205210
'llm_inference_function': esm_infer,
206211
'llm_loss_function': corr_loss,
207212
'x_llm' : x_esm,
213+
'input_ids': wt_tokens,
208214
'llm_attention_mask': esm_attention_mask,
209215
'llm_tokenizer': esm_tokenizer
210216
}

pypef/plm/inference.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,12 @@ def llm_tokenizer(llm_dict, seqs, verbose=True):
349349
if list(llm_dict.keys())[0] == 'esm1v':
350350
x_llm_seqs, _attention_mask = tokenize_sequences(
351351
seqs, tokenizer=llm_dict['esm1v']['llm_tokenizer'],
352-
max_length=len(seqs[0]), verbose=verbose
352+
max_length=len(seqs[0]) + 2, verbose=verbose
353353
)
354354
elif list(llm_dict.keys())[0] == 'prosst':
355-
x_llm_seqs = prosst_simple_vocab_aa_tokenizer(
356-
seqs, vocab=llm_dict['prosst']['llm_vocab'], verbose=verbose
355+
x_llm_seqs, _attention_mask = tokenize_sequences(
356+
seqs, tokenizer=llm_dict['prosst']['llm_tokenizer'],
357+
max_length=len(seqs[0]) + 2, verbose=verbose
357358
)
358359
else:
359360
raise SystemError(f"Unknown LLM dictionary input:\n{list(llm_dict.keys())[0]}")
@@ -376,17 +377,29 @@ def inference(
376377
device = get_device()
377378
if llm == 'esm':
378379
logger.info("Zero-shot LLM inference on test set using ESM1v...")
379-
llm_dict = esm_setup(sequences, verbose=verbose)
380+
llm_dict = esm_setup(wt_seq, sequences, verbose=verbose)
380381
if model is None:
381382
model = llm_dict['esm1v']['llm_base_model']
382383
x_llm_test = llm_tokenizer(llm_dict, sequences, verbose)
383384
y_test_pred = esm_infer(#llm_dict['esm1v']['llm_inference_function'](
384-
xs=torch.tensor(get_batches(x_llm_test, batch_size=1, dtype=int)),
385+
xs=torch.from_numpy(get_batches(x_llm_test, batch_size=1, dtype=int)),
385386
attention_mask=llm_dict['esm1v']['llm_attention_mask'],
386387
model=model,
387388
device=device,
388389
verbose=verbose
389390
).cpu()
391+
y_test_pred = plm_inference(
392+
xs=x_llm_test,
393+
wt_input_ids=torch.tensor(llm_dict['esm1v']['input_ids'][0], dtype=torch.long),
394+
attention_mask=llm_dict['esm1v']['llm_attention_mask'],
395+
model=model,
396+
mask_token_id=llm_dict['esm1v']['llm_tokenizer'].mask_token_id,
397+
inference_type='unmasked',
398+
batch_size=5,
399+
train=False,
400+
verbose=True
401+
).cpu()
402+
390403
elif llm == 'prosst':
391404
logger.info("Zero-shot LLM inference on test set using ProSST...")
392405
llm_dict = prosst_setup(
@@ -395,14 +408,27 @@ def inference(
395408
if model is None:
396409
model = llm_dict['prosst']['llm_base_model']
397410
x_llm_test = llm_tokenizer(llm_dict, sequences, verbose)
398-
y_test_pred = prosst_infer(#llm_dict['prosst']['llm_inference_function'](
399-
xs=x_llm_test,
400-
model=model,
401-
input_ids=llm_dict['prosst']['input_ids'],
402-
attention_mask=llm_dict['prosst']['llm_attention_mask'],
403-
structure_input_ids=llm_dict['prosst']['structure_input_ids'],
404-
verbose=verbose,
405-
device=device
411+
#y_test_pred = prosst_infer(#llm_dict['prosst']['llm_inference_function'](
412+
# xs=x_llm_test,
413+
# model=model,
414+
# input_ids=llm_dict['prosst']['input_ids'],
415+
# attention_mask=llm_dict['prosst']['llm_attention_mask'],
416+
# structure_input_ids=llm_dict['prosst']['structure_input_ids'],
417+
# verbose=verbose,
418+
# device=device
419+
#).cpu()
420+
print('XXX:', np.shape(x_llm_test))
421+
y_test_pred = plm_inference(
422+
xs=x_llm_test,
423+
wt_input_ids=llm_dict['prosst']['input_ids'],
424+
attention_mask=llm_dict['prosst']['llm_attention_mask'],
425+
model=model,
426+
mask_token_id=llm_dict['prosst']['llm_tokenizer'].mask_token_id,
427+
inference_type='mutation-masking',
428+
wt_structure_input_ids=llm_dict['prosst']['structure_input_ids'],
429+
batch_size=5,
430+
train=False,
431+
verbose=True
406432
).cpu()
407433
else:
408434
raise RuntimeError("Unknown LLM option.")

pypef/plm/prosst_lora_tune.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pypef.plm.utils import corr_loss
2727
from pypef.plm.prosst_structure.quantizer import PdbQuantizer
2828
from pypef.utils.helpers import get_device
29+
from pypef.plm.esm_lora_tune import tokenize_sequences
2930
from pypef.plm.utils import load_model_and_tokenizer
3031

3132

@@ -37,14 +38,13 @@ def prosst_simple_vocab_aa_tokenizer(sequences, vocab, verbose=True):
3738
sequences, desc='Tokenizing sequences for ProSST modeling',
3839
disable=not verbose
3940
):
40-
#x_sequence = [vocab['<cls>']]
41-
x_sequence = []
41+
x_sequence = [vocab['<cls>']]
4242
for aa in sequence:
4343
try:
4444
x_sequence.append(vocab[aa])
4545
except KeyError:
4646
x_sequence.append(vocab['<unk>'])
47-
#x_sequence.append(vocab['<eos>'])
47+
x_sequence.append(vocab['<eos>'])
4848
x_sequences.append(x_sequence)
4949
return torch.Tensor(x_sequences).to(torch.int)
5050

@@ -80,14 +80,15 @@ def get_logits_from_full_seqs(
8080
ss_input_ids=structure_input_ids
8181
)
8282
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze()
83-
for i_s, sequence in enumerate(
83+
for i_s, x_sequence in enumerate(
8484
tqdm(
8585
xs,
8686
desc=f'ProSST inference: getting sequence logits ({device.upper()})',
8787
disable=not verbose
8888
)
8989
):
90-
for i_aa, x_aa in enumerate(sequence):
90+
x_sequence = x_sequence[1:-1] # if cls, eos tokens included
91+
for i_aa, x_aa in enumerate(x_sequence):
9192
if i_aa == 0:
9293
seq_log_probs = logits[i_aa, x_aa].reshape(1)
9394
else:
@@ -297,8 +298,9 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
297298
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
298299
pdb_file, prosst_tokenizer, wt_seq, verbose=verbose
299300
)
300-
x_llm_train_prosst = prosst_simple_vocab_aa_tokenizer(
301-
sequences=sequences, vocab=prosst_vocab, verbose=verbose
301+
x_llm_train_prosst, _attention_mask = tokenize_sequences(
302+
sequences=sequences, tokenizer=prosst_tokenizer,
303+
max_length=len(wt_seq) + 2, verbose=verbose
302304
)
303305
llm_dict_prosst = {
304306
'prosst': {

pypef/plm/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def corr_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
2727
def get_batches(a, dtype, batch_size=5,
2828
keep_remaining=False, verbose: bool = False):
2929
a = np.asarray(a, dtype=dtype)
30+
a_remaining = None
3031
orig_shape = np.shape(a)
3132
remaining = len(a) % batch_size
3233
if remaining != 0:
@@ -46,12 +47,13 @@ def get_batches(a, dtype, batch_size=5,
4647
if verbose:
4748
print(f'{orig_shape} -> {new_shape} (dropped {remaining})')
4849
if keep_remaining:
49-
print(f'Appending remaining to collected batches as last batch '
50-
f'(the resulting inhomogenous list shape is '
51-
f'{np.shape(a)} + {np.shape(a_remaining)} = ('
52-
f'{np.shape(a)[0] + 1}, *, {np.shape(a)[-1]}))...')
53-
a = a.tolist()
54-
a.append(a_remaining)
50+
if a_remaining is not None:
51+
print(f'Appending remaining to collected batches as last batch '
52+
f'(the resulting inhomogenous list shape is '
53+
f'{np.shape(a)} + {np.shape(a_remaining)} = ('
54+
f'{np.shape(a)[0] + 1}, *, {np.shape(a)[-1]}))...')
55+
a = a.tolist()
56+
a.append(a_remaining)
5557
return a
5658

5759

tests/test_api_functions.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434

3535
torch.manual_seed(42)
36+
# torch.use_deterministic_algorithms(True)
3637
np.random.seed(42)
3738

3839
msa_file_avgfp = os.path.abspath(os.path.join(
@@ -111,29 +112,28 @@ def test_hybrid_model_dca_llm():
111112
decimal=7
112113
)
113114
assert len(train_seqs_aneh[0]) == len(g.wt_seq)
114-
115-
y_pred_esm = inference(train_seqs_aneh, 'esm')
115+
aneh_wt_seq = get_wt_sequence(wt_seq_file_aneh)
116+
y_pred_esm = inference(train_seqs_aneh, 'esm', wt_seq=aneh_wt_seq)
116117
np.testing.assert_almost_equal(
117118
spearmanr(train_ys_aneh, y_pred_esm)[0],
118-
-0.21073416060442696,
119+
-0.713214007088901,
119120
decimal=7
120121
)
121-
aneh_wt_seq = get_wt_sequence(wt_seq_file_aneh)
122122
y_pred_prosst = inference(
123123
train_seqs_aneh, 'prosst',
124124
pdb_file=pdb_file_aneh, wt_seq=aneh_wt_seq
125125
)
126126
np.testing.assert_almost_equal(
127127
spearmanr(train_ys_aneh, y_pred_prosst)[0],
128-
-0.7425657069861902,
128+
-0.7394433335146882,
129129
decimal=7
130130
)
131131

132132
x_dca_test = g.get_scores(test_seqs_aneh, encode=True)
133133
for i, setup in enumerate([esm_setup, prosst_setup]):
134134
print(['~~~ ESM ~~~', '~~~ ProSST ~~~'][i])
135135
if setup == esm_setup:
136-
llm_dict = setup(sequences=train_seqs_aneh)
136+
llm_dict = setup(sequences=train_seqs_aneh, wt_seq=aneh_wt_seq)
137137
else: # elif setup == prosst_setup:
138138
llm_dict = setup(
139139
aneh_wt_seq, pdb_file_aneh, sequences=train_seqs_aneh)
@@ -163,7 +163,7 @@ def test_hybrid_model_dca_llm():
163163
)
164164
np.testing.assert_almost_equal(
165165
spearmanr(hm.y_ttest, hm.y_llm_ttest)[0],
166-
[-0.21761360470606333, -0.8330644449247571][i],
166+
[-0.17231040881725562, -0.8330644449247571][i],
167167
decimal=7
168168
)
169169
# Nondeterministic behavior (without setting seed), should be about ~0.7 to ~0.9,
@@ -316,7 +316,6 @@ def test_plm_corr_blat_ecolx():
316316
#print(f'{x}: ESM1v (unsupervised performance): '
317317
# f'{spearmanr(y_true, y_esm.cpu())[0]}')
318318
#np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.666666666666666, decimal=6)
319-
print(prosst_vocab)
320319
wt_input_ids, prosst_attention_mask, wt_structure_input_ids = get_structure_quantizied(
321320
pdb_blat_ecolx, prosst_tokenizer, blat_ecolx_wt_seq)
322321
x_prosst2 = prosst_simple_vocab_aa_tokenizer(sequences, prosst_vocab)

0 commit comments

Comments
 (0)