Skip to content

Commit a7d2040

Browse files
committed
Add kwargs for plm inference functions
1 parent 77d2bbf commit a7d2040

File tree

2 files changed: +38 additions, −40 deletions

pypef/plm/inference.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,43 +39,31 @@ def unmasked_wt_score(
3939
train: bool = False,
4040
cut_special_tokens: bool = True, # assumption: cut first and last token
4141
device=None,
42-
**kwargs
42+
verbose: bool = False,
43+
**model_kwargs
4344
):
4445
if device is None:
4546
device = get_device()
4647
if wt_input_ids.dim() == 1:
4748
wt_input_ids = wt_input_ids.unsqueeze(0)
48-
structure_input_ids = kwargs.get("structure_input_ids", None)
49+
#structure_input_ids = model_kwargs.get("structure_input_ids", None)
4950

5051
attention_masks = torch.Tensor(np.full(
5152
shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64)
5253
if train:
53-
if structure_input_ids is not None:
54+
outputs = model(
55+
input_ids=wt_input_ids.to(device),
56+
attention_mask=attention_masks.to(device),
57+
**model_kwargs
58+
)
59+
60+
else:
61+
with torch.no_grad():
5462
outputs = model(
5563
input_ids=wt_input_ids.to(device),
5664
attention_mask=attention_masks.to(device),
57-
ss_input_ids=structure_input_ids.to(device)
58-
)
59-
else:
60-
outputs = model(
61-
wt_input_ids.to(device),
62-
attention_masks.to(device),
63-
output_hidden_states=False
65+
**model_kwargs
6466
)
65-
else:
66-
with torch.no_grad():
67-
if structure_input_ids is not None:
68-
outputs = model(
69-
input_ids=wt_input_ids.to(device),
70-
attention_mask=attention_masks.to(device),
71-
ss_input_ids=structure_input_ids.to(device)
72-
)
73-
else:
74-
outputs = model(
75-
wt_input_ids.to(device),
76-
attention_masks.to(device),
77-
output_hidden_states=False,
78-
)
7967

8068
logits = outputs.logits
8169
logits = logits.squeeze(0) # remove batch dim
@@ -105,7 +93,7 @@ def unmasked_wt_score(
10593
return log_probs
10694

10795

108-
def esm_mutation_only_mutation_masked_pll(
96+
def mutation_only_mutation_masked_pll(
10997
tokenized_sequences: torch.Tensor, # (L,)
11098
wt_input_ids: torch.Tensor, # (L,)
11199
attention_mask: torch.Tensor, # (L,)
@@ -198,7 +186,7 @@ def esm_mutation_only_mutation_masked_pll(
198186
return plls
199187

200188

201-
def esm_mutation_all_pos_masked_pll(
189+
def mutation_all_pos_masked_pll(
202190
tokenized_sequences: torch.Tensor, # (L,)
203191
attention_mask: torch.Tensor, # (L,)
204192
model,
@@ -285,7 +273,7 @@ def plm_inference(
285273
wt_input_ids,
286274
attention_mask,
287275
model,
288-
mask_token_id,
276+
mask_token_id = None,
289277
inference_type='unmasked',
290278
wt_structure_input_ids=None,
291279
batch_size=5,
@@ -304,9 +292,9 @@ def plm_inference(
304292
if not isinstance(attention_mask, torch.Tensor):
305293
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
306294
if inference_type == 'mutation-masking':
307-
inference_function = esm_mutation_only_mutation_masked_pll
295+
inference_function = mutation_only_mutation_masked_pll
308296
elif inference_type in ['full-masking', 'all-pos-masking']:
309-
inference_function = esm_mutation_all_pos_masked_pll
297+
inference_function = mutation_all_pos_masked_pll
310298
elif inference_type in ['unmasked', 'wt-marginals']:
311299
inference_function = unmasked_wt_score
312300
else:
@@ -317,6 +305,13 @@ def plm_inference(
317305
xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
318306
desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
319307

308+
kwargs = {}
309+
if mask_token_id is not None:
310+
kwargs["mask_token_id"] = mask_token_id
311+
312+
if wt_structure_input_ids is not None:
313+
kwargs["structure_input_ids"] = wt_structure_input_ids
314+
320315
pbar = tqdm(
321316
range(len(xs_b)),
322317
desc=desc,
@@ -327,13 +322,12 @@ def plm_inference(
327322
pll = inference_function(
328323
tokenized_sequences=torch.tensor(xs_b[i]),
329324
wt_input_ids=wt_input_ids,
330-
structure_input_ids=wt_structure_input_ids,
331325
attention_mask=attention_mask,
332326
model=model,
333-
mask_token_id=mask_token_id,
334327
train=train,
335328
device=device,
336-
verbose=False
329+
verbose=False,
330+
**kwargs
337331
)
338332
scores.append(pll)
339333
return torch.cat(scores)

tests/test_api_functions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272

7373

7474
def test_gremlin_avgfp():
75+
print("test_gremlin_avgfp()...")
7576
g = GREMLIN(
7677
alignment=msa_file_avgfp,
7778
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
@@ -87,6 +88,7 @@ def test_gremlin_avgfp():
8788

8889

8990
def test_hybrid_model_dca_llm():
91+
print("test_hybrid_model_dca_llm()...")
9092
g = GREMLIN(
9193
alignment=msa_file_aneh,
9294
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
@@ -225,6 +227,7 @@ def test_hybrid_model_dca_llm():
225227

226228

227229
def test_dataset_b_results():
230+
print("test_dataset_b_results()...")
228231
aaindex = "WOLR810101.txt"
229232
x_fft_train, _ = AAIndexEncoding(
230233
full_aaidx_txt_path(aaindex), train_seqs_aneh
@@ -250,6 +253,7 @@ def test_dataset_b_results():
250253

251254
@pytest.mark.requires_gpu
252255
def test_plm_corr_blat_ecolx():
256+
print("test_plm_corr_blat_ecolx()...")
253257
device = get_device()
254258
print("Device", device)
255259
blat_ecolx_wt_seq = get_wt_sequence(wt_seq_file_blat_ecolx)
@@ -283,7 +287,7 @@ def test_plm_corr_blat_ecolx():
283287
train=False,
284288
verbose=True
285289
)
286-
print(f'{x}: ESM1v (unsupervised performance): '
290+
print(f'{x}: ESM1v (unsupervised performance mutation-masking): '
287291
f'{spearmanr(y_true, y_esm.cpu())[0]}')
288292
np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6367826285982324, decimal=6)
289293

@@ -292,13 +296,13 @@ def test_plm_corr_blat_ecolx():
292296
wt_input_ids=wt_tokens,
293297
attention_mask=esm_attention_mask,
294298
model=esm_base_model,
295-
mask_token_id=esm_tokenizer.mask_token_id,
299+
mask_token_id=None, # do not define for unmasked
296300
inference_type='unmasked',
297301
batch_size=5,
298302
train=False,
299303
verbose=True
300304
)
301-
print(f'{x}: ESM1v (unsupervised performance): '
305+
print(f'{x}: ESM1v (unsupervised performance unmasked): '
302306
f'{spearmanr(y_true, y_esm.cpu())[0]}')
303307
np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.6498987261125897, decimal=6)
304308

@@ -341,7 +345,7 @@ def test_plm_corr_blat_ecolx():
341345
train=False,
342346
verbose=True
343347
)
344-
print(f'ProSST (unsupervised performance): ' # ProSST not made/trained for this: 0.607137337377509
348+
print(f'ProSST (unsupervised performance): ' # ProSST not made/trained for MLM: 0.607137337377509
345349
f'{spearmanr(y_true, y_prosst.cpu())[0]}')
346350
np.testing.assert_almost_equal(spearmanr(y_true, y_prosst.cpu())[0], 0.607137337377509, decimal=6)
347351

@@ -351,7 +355,7 @@ def test_plm_corr_blat_ecolx():
351355
wt_input_ids=wt_input_ids,
352356
attention_mask=prosst_attention_mask,
353357
model=prosst_base_model,
354-
mask_token_id=prosst_tokenizer.mask_token_id,
358+
mask_token_id=None, # do not define for unmasked
355359
inference_type='unmasked',
356360
wt_structure_input_ids=wt_structure_input_ids,
357361
batch_size=5,
@@ -379,8 +383,8 @@ def test_plm_corr_blat_ecolx():
379383

380384

381385
if __name__ == "__main__":
382-
test_gremlin_avgfp()
383-
test_hybrid_model_dca_llm()
384-
test_dataset_b_results()
386+
#test_gremlin_avgfp()
387+
#test_hybrid_model_dca_llm()
388+
#test_dataset_b_results()
385389
test_plm_corr_blat_ecolx()
386390

0 commit comments

Comments (0)