Skip to content

Commit 7c3c2f9

Browse files
committed
small cleanup
1 parent 0e927dd commit 7c3c2f9

File tree

4 files changed

+337
-322
lines changed

4 files changed

+337
-322
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
409409
def train_llm(self):
410410
# LoRA training on y_llm_ttrain --> Testing on y_llm_ttest
411411
x_llm_ttrain_b, scores_ttrain_b = (
412-
get_batches(self.x_llm_ttrain, batch_size=self.batch_size, dtype=int),
413-
get_batches(self.y_ttrain, batch_size=self.batch_size, dtype=float)
412+
torch.from_numpy(get_batches(self.x_llm_ttrain, batch_size=self.batch_size, dtype=int)),
413+
torch.from_numpy(get_batches(self.y_ttrain, batch_size=self.batch_size, dtype=float))
414414
)
415415

416416
if self.llm_key == 'prosst':
@@ -431,7 +431,7 @@ def train_llm(self):
431431
device=self.device
432432
)
433433
elif self.llm_key == 'esm1v':
434-
x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=1, dtype=int)
434+
x_llm_ttest_b = torch.from_numpy(get_batches(self.x_llm_ttest, batch_size=1, dtype=int))
435435
y_llm_ttest = self.llm_inference_function(
436436
xs=x_llm_ttest_b,
437437
model=self.llm_model,
@@ -633,7 +633,7 @@ def hybrid_prediction(
633633
verbose=verbose,
634634
device=self.device).detach().cpu().numpy()
635635
elif self.llm_key == 'esm1v':
636-
x_llm_b = get_batches(x_llm, batch_size=1, dtype=int)
636+
x_llm_b = torch.from_numpy(get_batches(x_llm, batch_size=1, dtype=int))
637637
y_llm = self.llm_inference_function(
638638
x_llm_b,
639639
self.llm_attention_mask,
@@ -662,7 +662,7 @@ def hybrid_prediction(
662662
def ls_ts_performance(self):
663663
beta_1, beta_2, reg = self.settings(
664664
x_train=self.x_train,
665-
y_train=self.y_train
665+
y_train=self.y_train
666666
)
667667
spearman_r = self.spearmanr(
668668
self.y_test,

pypef/plm/esm_lora_tune.py

Lines changed: 0 additions & 311 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,13 @@
1818

1919
import logging
2020

21-
from pypef.plm.prosst_lora_tune import get_logits_from_full_seqs
2221
logger = logging.getLogger('pypef.llm.esm_lora_tune')
2322

2423
import torch
25-
import torch.nn.functional as F
2624
import numpy as np
2725
from scipy.stats import spearmanr
2826
from tqdm import tqdm
2927

30-
3128
from peft import LoraConfig, get_peft_model
3229
from transformers import logging as hf_logging
3330
hf_logging.set_verbosity_error()
@@ -143,314 +140,6 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=Fals
143140
return torch.flatten(y_preds_total)
144141

145142

146-
def unmasked_wt_score(
147-
tokenized_sequences,
148-
attention_mask,
149-
wt_input_ids,
150-
model,
151-
train: bool = False,
152-
cut_special_tokens: bool = True, # assumption: cut first and last token
153-
device=None,
154-
**kwargs
155-
):
156-
if device is None:
157-
device = get_device()
158-
if wt_input_ids.dim() == 1:
159-
wt_input_ids = wt_input_ids.unsqueeze(0)
160-
structure_input_ids = kwargs.get("structure_input_ids", None)
161-
162-
attention_masks = torch.Tensor(np.full(
163-
shape=np.shape(wt_input_ids), fill_value=attention_mask)).to(torch.int64)
164-
if train:
165-
if structure_input_ids is not None:
166-
outputs = model(
167-
input_ids=wt_input_ids.to(device),
168-
attention_mask=attention_masks.to(device),
169-
ss_input_ids=structure_input_ids.to(device)
170-
)
171-
else:
172-
outputs = model(
173-
wt_input_ids.to(device),
174-
attention_masks.to(device),
175-
output_hidden_states=False
176-
)
177-
else:
178-
with torch.no_grad():
179-
if structure_input_ids is not None:
180-
outputs = model(
181-
input_ids=wt_input_ids.to(device),
182-
attention_mask=attention_masks.to(device),
183-
ss_input_ids=structure_input_ids.to(device)
184-
)
185-
else:
186-
outputs = model(
187-
wt_input_ids.to(device),
188-
attention_masks.to(device),
189-
output_hidden_states=False,
190-
)
191-
192-
logits = outputs.logits
193-
logits = logits.squeeze(0) # remove batch dim
194-
# Better make sure that special tokens are always removed / masked
195-
# and only pure amino acid sequence tokens are present / unmasked
196-
tokenized_seq_len = tokenized_sequences.shape[1]
197-
if cut_special_tokens:
198-
logits = logits[1:-1] # drop CLS/EOS
199-
tokenized_seq_len -= 2
200-
token_probs = torch.log_softmax(logits, dim=-1)
201-
assert tokenized_seq_len == token_probs.shape[0], (
202-
f"{tokenized_seq_len} != {token_probs.shape[0]}")
203-
204-
log_probs = []
205-
for tokenized_seq in tokenized_sequences:
206-
if cut_special_tokens:
207-
tokenized_seq = tokenized_seq[1:-1]
208-
209-
seq_lp = token_probs[
210-
torch.arange(tokenized_seq.shape[0], device=tokenized_seq.device),
211-
tokenized_seq
212-
].sum(dtype=torch.float64)
213-
214-
log_probs.append(seq_lp)
215-
216-
log_probs = torch.stack(log_probs)
217-
return log_probs
218-
219-
220-
def esm_mutation_only_mutation_masked_pll(
221-
tokenized_sequences: torch.Tensor, # (L,)
222-
wt_input_ids: torch.Tensor, # (L,)
223-
attention_mask: torch.Tensor, # (L,)
224-
model,
225-
mask_token_id: int,
226-
train: bool = False,
227-
device: str | None = None,
228-
verbose: bool = False,
229-
**kwargs
230-
):
231-
"""
232-
Correct mutation-only pseudo-log-likelihood for sequences.
233-
"""
234-
tokenized_sequences = tokenized_sequences.to(device)
235-
structure_input_ids = kwargs.get("structure_input_ids", None)
236-
if structure_input_ids is not None:
237-
assert structure_input_ids.shape[1] == tokenized_sequences.shape[1], (
238-
f"{structure_input_ids.shape[1]} != {tokenized_sequences.shape[1]}")
239-
structure_input_ids = structure_input_ids.to(device)
240-
if wt_input_ids.dim() == 2 and wt_input_ids.shape[0] == 1:
241-
wt_input_ids = wt_input_ids.squeeze(0)
242-
wt_input_ids = wt_input_ids.to(device)
243-
if attention_mask.dim() == 2 and attention_mask.shape[0] == 1:
244-
attention_mask = attention_mask.squeeze(0)
245-
attention_mask = attention_mask.to(device)
246-
plls = torch.empty(len(tokenized_sequences), device=device)
247-
for i, tokenized_seq in enumerate(tokenized_sequences):
248-
assert tokenized_seq.dim() == 1
249-
assert wt_input_ids.dim() == 1
250-
assert attention_mask.dim() == 1
251-
assert tokenized_seq.shape == wt_input_ids.shape == attention_mask.shape
252-
pll = torch.tensor(0.0, device=device)
253-
254-
# Identify mutated positions (exclude padding, CLS, EOS)
255-
diff = (tokenized_seq != wt_input_ids) & (attention_mask == 1)
256-
diff[0] = False
257-
diff[-1] = False
258-
259-
mutated_positions = diff.nonzero(as_tuple=False).flatten()
260-
# n_mutations = (tokenized_seq != wt_input_ids).sum().item()
261-
# Mutated positions: [int(m) - 1 for m in mutated_positions.cpu()] # Remove CLS token position
262-
263-
for pos in tqdm(
264-
mutated_positions,
265-
desc="Masked PLL (single sequence)",
266-
disable=not verbose
267-
):
268-
masked_input_ids = tokenized_seq.clone()
269-
masked_input_ids[pos] = mask_token_id
270-
if structure_input_ids is not None:
271-
masked_ss_input_ids = structure_input_ids.clone()
272-
masked_ss_input_ids[0, pos] = mask_token_id
273-
274-
if train:
275-
if structure_input_ids is not None:
276-
outputs = model(
277-
input_ids=masked_input_ids.unsqueeze(0),
278-
attention_mask=attention_mask.unsqueeze(0),
279-
ss_input_ids=masked_ss_input_ids # Check
280-
)
281-
else:
282-
outputs = model(
283-
input_ids=masked_input_ids.unsqueeze(0),
284-
attention_mask=attention_mask.unsqueeze(0),
285-
output_hidden_states=False
286-
)
287-
else:
288-
with torch.no_grad():
289-
if structure_input_ids is not None:
290-
outputs = model(
291-
input_ids=masked_input_ids.unsqueeze(0),
292-
attention_mask=attention_mask.unsqueeze(0),
293-
ss_input_ids=masked_ss_input_ids # Check
294-
)
295-
else:
296-
outputs = model(
297-
input_ids=masked_input_ids.unsqueeze(0),
298-
attention_mask=attention_mask.unsqueeze(0),
299-
output_hidden_states=False
300-
)
301-
logits = outputs.logits # (1, L, V)
302-
303-
log_probs = F.log_softmax(logits[0, pos], dim=-1)
304-
true_token = tokenized_seq[pos]
305-
306-
pll = pll + log_probs[true_token]
307-
308-
plls[i] = pll
309-
310-
return plls
311-
312-
313-
def esm_mutation_all_pos_masked_pll(
314-
tokenized_sequences: torch.Tensor, # (L,)
315-
attention_mask: torch.Tensor, # (L,)
316-
model,
317-
mask_token_id: int,
318-
train: bool = False,
319-
device: str | None = None,
320-
verbose: bool = False,
321-
**kwargs
322-
):
323-
"""
324-
Correct mutation-only pseudo-log-likelihood for sequences.
325-
"""
326-
structure_input_ids = kwargs.get("structure_input_ids", None)
327-
if structure_input_ids is not None:
328-
assert structure_input_ids.shape[1] == tokenized_sequences.shape[1], (
329-
f"{structure_input_ids.shape[1]} != {tokenized_sequences.shape[1]}")
330-
structure_input_ids = structure_input_ids.to(device)
331-
tokenized_sequences = tokenized_sequences.to(device)
332-
if attention_mask.dim() == 2 and attention_mask.shape[0] == 1:
333-
attention_mask = attention_mask.squeeze(0)
334-
attention_mask = attention_mask.to(device)
335-
plls = torch.empty(len(tokenized_sequences), device=device)
336-
for i, tokenized_seq in enumerate(tokenized_sequences):
337-
L = tokenized_seq.shape[0]
338-
pll = torch.tensor(0.0, device=device)
339-
340-
# Positions to score: all real tokens except CLS/EOS
341-
positions = (attention_mask == 1).nonzero(as_tuple=False).flatten()
342-
positions = positions[(positions != 0) & (positions != L - 1)]
343-
344-
345-
for pos in tqdm(
346-
positions,
347-
desc="Masked PLL (single sequence)",
348-
disable=not verbose
349-
):
350-
masked_input_ids = tokenized_seq.clone()
351-
masked_input_ids[pos] = mask_token_id
352-
353-
if structure_input_ids is not None:
354-
masked_ss_input_ids = structure_input_ids.clone()
355-
masked_ss_input_ids[0, pos] = mask_token_id
356-
357-
if train:
358-
if structure_input_ids is not None:
359-
outputs = model(
360-
input_ids=masked_input_ids.unsqueeze(0),
361-
attention_mask=attention_mask.unsqueeze(0),
362-
ss_input_ids=masked_ss_input_ids # Check
363-
)
364-
else:
365-
outputs = model(
366-
input_ids=masked_input_ids.unsqueeze(0),
367-
attention_mask=attention_mask.unsqueeze(0),
368-
output_hidden_states=False
369-
)
370-
else:
371-
with torch.no_grad():
372-
if structure_input_ids is not None:
373-
outputs = model(
374-
input_ids=masked_input_ids.unsqueeze(0),
375-
attention_mask=attention_mask.unsqueeze(0),
376-
ss_input_ids=masked_ss_input_ids # Check
377-
)
378-
else:
379-
outputs = model(
380-
input_ids=masked_input_ids.unsqueeze(0),
381-
attention_mask=attention_mask.unsqueeze(0),
382-
output_hidden_states=False
383-
)
384-
logits = outputs.logits # (1, L, V)
385-
386-
log_probs = F.log_softmax(logits[0, pos], dim=-1)
387-
true_token = tokenized_seq[pos]
388-
pll = pll + log_probs[true_token]
389-
390-
plls[i] = pll
391-
392-
return plls
393-
394-
395-
def plm_inference(
396-
xs,
397-
wt_input_ids,
398-
attention_mask,
399-
model,
400-
mask_token_id,
401-
inference_type='unmasked',
402-
wt_structure_input_ids=None,
403-
batch_size=5,
404-
train=False,
405-
device=None,
406-
verbose=False,
407-
):
408-
if device is None:
409-
device = get_device()
410-
411-
model = model.to(device)
412-
413-
if not isinstance(xs, torch.Tensor):
414-
xs = torch.tensor(xs, dtype=torch.long)
415-
416-
if not isinstance(attention_mask, torch.Tensor):
417-
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
418-
if inference_type == 'mutation-masking':
419-
inference_function = esm_mutation_only_mutation_masked_pll
420-
elif inference_type in ['full-masking', 'all-pos-masking']:
421-
inference_function = esm_mutation_all_pos_masked_pll
422-
elif inference_type in ['unmasked', 'wt-marginals']:
423-
inference_function = unmasked_wt_score
424-
else:
425-
raise SystemError("Choose between 'mutation-masking', 'unmasked', and 'full-masking'")
426-
427-
scores = []
428-
429-
xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
430-
desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})'"
431-
432-
pbar = tqdm(
433-
range(len(xs_b)),
434-
desc=desc,
435-
disable=not verbose
436-
)
437-
438-
for i in pbar:
439-
pll = inference_function(
440-
tokenized_sequences=torch.tensor(xs_b[i]),
441-
wt_input_ids=wt_input_ids,
442-
structure_input_ids=wt_structure_input_ids,
443-
attention_mask=attention_mask,
444-
model=model,
445-
mask_token_id=mask_token_id,
446-
train=train,
447-
device=device,
448-
verbose=False
449-
)
450-
scores.append(pll)
451-
return torch.cat(scores)
452-
453-
454143
def esm_train(
455144
xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
456145
device: str | None = None, seed: int | None = None,

0 commit comments

Comments
 (0)