Skip to content

Commit 3ef22b3

Browse files
committed
Full unmasked (WT-relative scoring)
1 parent 7c0243f commit 3ef22b3

File tree

3 files changed

+61
-159
lines changed

3 files changed

+61
-159
lines changed

pypef/plm/esm_lora_tune.py

Lines changed: 36 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from time import sleep
20+
21+
from pypef.plm.prosst_lora_tune import get_logits_from_full_seqs
2122
logger = logging.getLogger('pypef.llm.esm_lora_tune')
2223

2324
import torch
@@ -142,35 +143,37 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=Fals
142143
return torch.flatten(y_preds_total)
143144

144145

145-
def esm_unmasked_reconstruction_score(
146+
def esm_unmasked_wt_score(
146147
tokenized_sequences,
147148
attention_mask,
149+
wt_input_ids,
148150
model,
149151
train: bool = False,
150152
device=None,
151-
**kws
153+
**kwargs
152154
):
153155
if device is None:
154156
device = get_device()
157+
wt_input_ids = wt_input_ids.unsqueeze(0)
155158
attention_masks = torch.Tensor(np.full(
156-
shape=np.shape(tokenized_sequences), fill_value=attention_mask)).to(torch.int64)
159+
shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64)
157160
if train:
158-
with torch.no_grad():
159-
outputs = model(tokenized_sequences.to(device), attention_masks.to(device),
160-
output_hidden_states=False)
161+
outputs = model(wt_input_ids.to(device), attention_masks.to(device),
162+
output_hidden_states=False)
161163
else:
162-
outputs = model(tokenized_sequences.to(device), attention_masks.to(device),
164+
with torch.no_grad():
165+
outputs = model(wt_input_ids.to(device), attention_masks.to(device),
163166
output_hidden_states=False)
164167
logits = outputs.logits
165-
token_probs = torch.log_softmax(logits, dim=-1)
166-
for i_s, sequence in enumerate(tokenized_sequences):
167-
for i_aa, aa in enumerate(sequence):
168+
token_probs = torch.log_softmax(logits, dim=-1).squeeze(0)
169+
for i_s, tokenized_seq in enumerate(tokenized_sequences):
170+
for i_aa, aa in enumerate(tokenized_seq):
168171
# alternative: use Tensor.index_select() function
169172
if i_aa == 0:
170-
seq_log_probs = token_probs[i_s, i_aa, aa].reshape(1)
173+
seq_log_probs = token_probs[i_aa, aa].reshape(1)
171174
else:
172175
seq_log_probs = torch.cat(
173-
(seq_log_probs, token_probs[i_s, i_aa, aa].reshape(1)), 0)
176+
(seq_log_probs, token_probs[i_aa, aa].reshape(1)), 0)
174177
if i_s == 0:
175178
log_probs = torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
176179
else:
@@ -179,124 +182,6 @@ def esm_unmasked_reconstruction_score(
179182
return log_probs
180183

181184

182-
def esm_masked_pll(
183-
input_ids: torch.Tensor, # (B, L)
184-
attention_mask: torch.Tensor, # (B, L)
185-
model,
186-
mask_token_id: int,
187-
device: str | None = None,
188-
verbose: bool = False,
189-
):
190-
"""
191-
Compute true pseudo-log-likelihood (PLL) for an MLM (ESM).
192-
193-
Returns:
194-
pll_scores: torch.Tensor of shape (B,)
195-
"""
196-
if device is None:
197-
device = next(model.parameters()).device
198-
199-
input_ids = input_ids.to(device)
200-
attention_mask = attention_mask.to(device)
201-
202-
B, L = input_ids.shape
203-
pll_scores = torch.zeros(B, device=device)
204-
205-
model.eval()
206-
207-
for pos in tqdm(
208-
range(L),
209-
desc="ESM masked PLL",
210-
disable=not verbose
211-
):
212-
# Skip padding positions (position padding for all sequences in the batch)
213-
if attention_mask[:, pos].sum() == 0:
214-
continue
215-
216-
# Clone and mask position `pos`
217-
masked_input_ids = input_ids.clone()
218-
masked_input_ids[:, pos] = mask_token_id
219-
220-
with torch.no_grad():
221-
outputs = model(
222-
input_ids=masked_input_ids,
223-
attention_mask=attention_mask,
224-
)
225-
226-
logits = outputs.logits # (B, L, V)
227-
228-
# Log-probabilities at masked position
229-
log_probs = F.log_softmax(logits[:, pos, :], dim=-1)
230-
231-
# True tokens at this position
232-
true_tokens = input_ids[:, pos]
233-
234-
# Gather log-prob of the true token
235-
token_log_probs = log_probs.gather(
236-
dim=1,
237-
index=true_tokens.unsqueeze(1)
238-
).squeeze(1)
239-
240-
# Only count non-padding
241-
pll_scores += token_log_probs * attention_mask[:, pos]
242-
243-
return pll_scores
244-
245-
246-
def esm_infer_masked_pll(
247-
xs,
248-
attention_mask,
249-
model,
250-
mask_token_id,
251-
batch_size: int = 4,
252-
device: str | None = None,
253-
verbose: bool = False,
254-
):
255-
if device is None:
256-
device = get_device()
257-
258-
model = model.to(device)
259-
model.eval()
260-
261-
if not isinstance(xs, torch.Tensor):
262-
xs = torch.tensor(xs, dtype=torch.long)
263-
264-
if not isinstance(attention_mask, torch.Tensor):
265-
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
266-
267-
xs = xs.to(device)
268-
269-
# Expand mask to (N, L) if needed
270-
if attention_mask.dim() == 1:
271-
attention_mask = attention_mask.unsqueeze(0).expand(xs.shape[0], -1)
272-
273-
attention_mask = attention_mask.to(device)
274-
275-
pll_all = []
276-
277-
for i in tqdm(
278-
range(0, xs.shape[0], batch_size),
279-
desc="ESM PLL inference",
280-
disable=not verbose,
281-
):
282-
xs_b = xs[i:i + batch_size]
283-
am_b = attention_mask[i:i + batch_size]
284-
285-
pll_b = esm_masked_pll(
286-
input_ids=xs_b,
287-
attention_mask=am_b,
288-
model=model,
289-
mask_token_id=mask_token_id,
290-
device=device,
291-
verbose=False,
292-
)
293-
294-
pll_all.append(pll_b.cpu())
295-
296-
return torch.cat(pll_all)
297-
298-
299-
300185
def esm_mutation_only_mutation_masked_pll(
301186
tokenized_sequences: torch.Tensor, # (L,)
302187
wt_input_ids: torch.Tensor, # (L,)
@@ -306,6 +191,7 @@ def esm_mutation_only_mutation_masked_pll(
306191
train: bool = False,
307192
device: str | None = None,
308193
verbose: bool = False,
194+
**kwargs
309195
):
310196
"""
311197
Correct mutation-only pseudo-log-likelihood for ONE sequence.
@@ -335,16 +221,16 @@ def esm_mutation_only_mutation_masked_pll(
335221
masked_input_ids = tokenized_seq.clone()
336222
masked_input_ids[pos] = mask_token_id
337223
if train:
224+
outputs = model(
225+
input_ids=masked_input_ids.unsqueeze(0),
226+
attention_mask=attention_mask.unsqueeze(0),
227+
)
228+
else:
338229
with torch.no_grad():
339230
outputs = model(
340231
input_ids=masked_input_ids.unsqueeze(0),
341232
attention_mask=attention_mask.unsqueeze(0),
342233
)
343-
else:
344-
outputs = model(
345-
input_ids=masked_input_ids.unsqueeze(0),
346-
attention_mask=attention_mask.unsqueeze(0),
347-
)
348234
logits = outputs.logits # (1, L, V)
349235

350236
log_probs = F.log_softmax(logits[0, pos], dim=-1)
@@ -393,16 +279,16 @@ def esm_mutation_all_pos_masked_pll(
393279
masked_input_ids[pos] = mask_token_id
394280

395281
if train:
282+
outputs = model(
283+
input_ids=masked_input_ids.unsqueeze(0),
284+
attention_mask=attention_mask.unsqueeze(0),
285+
)
286+
else:
396287
with torch.no_grad():
397288
outputs = model(
398289
input_ids=masked_input_ids.unsqueeze(0),
399290
attention_mask=attention_mask.unsqueeze(0),
400291
)
401-
else:
402-
outputs = model(
403-
input_ids=masked_input_ids.unsqueeze(0),
404-
attention_mask=attention_mask.unsqueeze(0),
405-
)
406292
logits = outputs.logits # (1, L, V)
407293

408294
log_probs = F.log_softmax(logits[0, pos], dim=-1)
@@ -437,13 +323,16 @@ def esm_infer_pll(
437323

438324
if not isinstance(attention_mask, torch.Tensor):
439325
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
440-
441-
if inference_type == 'mutation_masking':
326+
wt_structure_input_ids = None
327+
if inference_type == 'mutation-masking':
442328
inference_function = esm_mutation_only_mutation_masked_pll
443-
elif inference_type == 'full_masking':
329+
elif inference_type in ['full-masking', 'all-pos-masking']:
444330
inference_function = esm_mutation_all_pos_masked_pll
445-
elif inference_type == 'unmasked':
446-
inference_function = esm_unmasked_reconstruction_score
331+
elif inference_type in ['unmasked', 'wt-marginals']:
332+
inference_function = esm_unmasked_wt_score
333+
elif inference_type == 'prosst':
334+
wt_input_ids, wt_structure_input_ids = wt_input_ids
335+
inference_function = esm_unmasked_wt_score
447336
else:
448337
raise SystemError("Choose between 'mutation_masking', 'unmasked', and 'full_masking'")
449338

@@ -462,6 +351,7 @@ def esm_infer_pll(
462351
pll = inference_function(
463352
tokenized_sequences=torch.tensor(xs_b[i]),
464353
wt_input_ids=wt_input_ids,
354+
structure_input_ids=wt_structure_input_ids,
465355
attention_mask=attention_mask,
466356
model=model,
467357
mask_token_id=mask_token_id,

pypef/plm/prosst_lora_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from Bio import SeqIO, BiopythonParserWarning
2424
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
2525

26-
from pypef.plm.esm_lora_tune import corr_loss
26+
from pypef.plm.utils import corr_loss
2727
from pypef.plm.prosst_structure.quantizer import PdbQuantizer
2828
from pypef.utils.helpers import get_device
2929
from pypef.plm.utils import load_model_and_tokenizer

tests/test_api_functions.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,6 @@ def test_plm_corr_blat_ecolx():
268268
esm_base_model = esm_base_model.to(device)
269269
x_esm, esm_attention_mask = esm_tokenize_sequences(
270270
sequences, esm_tokenizer, max_length=len(blat_ecolx_wt_seq) + 2)
271-
272271
# Tokenize WT sequence once
273272
wt_tokens, _ = esm_tokenize_sequences(
274273
[blat_ecolx_wt_seq],
@@ -282,15 +281,14 @@ def test_plm_corr_blat_ecolx():
282281
attention_mask=esm_attention_mask,
283282
model=esm_base_model,
284283
mask_token_id=esm_tokenizer.mask_token_id,
285-
inference_type='mutation_masking',
284+
inference_type='mutation-masking',
286285
batch_size=5,
287286
train=False,
288287
verbose=True
289288
)
290289
print(f'{x}: ESM1v (unsupervised performance): '
291290
f'{spearmanr(y_true, y_esm.cpu())[0]}')
292291
np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6367826285982324, decimal=6)
293-
294292
y_esm = esm_infer_pll(
295293
xs=x_esm,
296294
wt_input_ids=wt_tokens,
@@ -304,15 +302,14 @@ def test_plm_corr_blat_ecolx():
304302
)
305303
print(f'{x}: ESM1v (unsupervised performance): '
306304
f'{spearmanr(y_true, y_esm.cpu())[0]}')
307-
np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6381789551033011, decimal=6)
308-
305+
np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6498987261125897, decimal=6)
309306
#y_esm = esm_infer_pll(
310307
# xs=x_esm,
311308
# wt_input_ids=wt_tokens,
312309
# attention_mask=esm_attention_mask,
313310
# model=esm_base_model,
314311
# mask_token_id=esm_tokenizer.mask_token_id,
315-
# inference_type='full_masking',
312+
# inference_type='full-masking',
316313
# batch_size=5,
317314
# train=False,
318315
# verbose=True
@@ -321,15 +318,30 @@ def test_plm_corr_blat_ecolx():
321318
# f'{spearmanr(y_true, y_esm.cpu())[0]}')
322319
#np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6360209552304472, decimal=6)
323320

324-
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
321+
wt_input_ids, prosst_attention_mask, wt_structure_input_ids = get_structure_quantizied(
325322
pdb_blat_ecolx, prosst_tokenizer, blat_ecolx_wt_seq)
326323
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
327-
y_prosst = get_logits_from_full_seqs(
328-
x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
329-
structure_input_ids, train=False, verbose=True
324+
#y_prosst = get_logits_from_full_seqs(
325+
# x_prosst, prosst_base_model, wt_input_ids, prosst_attention_mask,
326+
# wt_structure_input_ids, train=False, verbose=True
327+
#)
328+
#print(f'ProSST (unsupervised performance): ' # ProteinGym: ProSST: 0.760
329+
# f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
330+
print('wt_input_ids:',wt_input_ids)
331+
print()
332+
print('wt_structure_input_ids:', wt_structure_input_ids)
333+
print()
334+
y_prosst = esm_infer_pll(
335+
xs=x_prosst,
336+
wt_input_ids=(wt_input_ids, wt_structure_input_ids), ## TODO
337+
attention_mask=prosst_attention_mask,
338+
model=prosst_base_model,
339+
mask_token_id=prosst_tokenizer.mask_token_id,
340+
inference_type='prosst', ## TODO
341+
batch_size=5,
342+
train=False,
343+
verbose=True
330344
)
331-
print(f'ProSST (unsupervised performance): ' # ProteinGym: ProSST: 0.760
332-
f'{spearmanr(y_true, y_prosst.cpu())[0]:.3f}')
333345
# ACTUAL OLD VERSION: 0.743
334346

335347

0 commit comments

Comments (0)