1717from __future__ import annotations
1818
1919import logging
20- from time import sleep
20+
21+ from pypef .plm .prosst_lora_tune import get_logits_from_full_seqs
2122logger = logging .getLogger ('pypef.llm.esm_lora_tune' )
2223
2324import torch
@@ -142,35 +143,37 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=Fals
142143 return torch .flatten (y_preds_total )
143144
144145
145- def esm_unmasked_reconstruction_score (
146+ def esm_unmasked_wt_score (
146147 tokenized_sequences ,
147148 attention_mask ,
149+ wt_input_ids ,
148150 model ,
149151 train : bool = False ,
150152 device = None ,
151- ** kws
153+ ** kwargs
152154 ):
153155 if device is None :
154156 device = get_device ()
157+ wt_input_ids = wt_input_ids .unsqueeze (0 )
155158 attention_masks = torch .Tensor (np .full (
156- shape = np .shape (tokenized_sequences ), fill_value = attention_mask )).to (torch .int64 )
159+ shape = np .shape (wt_input_ids ), fill_value = attention_mask )).to (torch .int64 )
157160 if train :
158- with torch .no_grad ():
159- outputs = model (tokenized_sequences .to (device ), attention_masks .to (device ),
160- output_hidden_states = False )
161+ outputs = model (wt_input_ids .to (device ), attention_masks .to (device ),
162+ output_hidden_states = False )
161163 else :
162- outputs = model (tokenized_sequences .to (device ), attention_masks .to (device ),
164+ with torch .no_grad ():
165+ outputs = model (wt_input_ids .to (device ), attention_masks .to (device ),
163166 output_hidden_states = False )
164167 logits = outputs .logits
165- token_probs = torch .log_softmax (logits , dim = - 1 )
166- for i_s , sequence in enumerate (tokenized_sequences ):
167- for i_aa , aa in enumerate (sequence ):
168+ token_probs = torch .log_softmax (logits , dim = - 1 ). squeeze ( 0 )
169+ for i_s , tokenized_seq in enumerate (tokenized_sequences ):
170+ for i_aa , aa in enumerate (tokenized_seq ):
168171 # alternative: use Tensor.index_select() function
169172 if i_aa == 0 :
170- seq_log_probs = token_probs [i_s , i_aa , aa ].reshape (1 )
173+ seq_log_probs = token_probs [i_aa , aa ].reshape (1 )
171174 else :
172175 seq_log_probs = torch .cat (
173- (seq_log_probs , token_probs [i_s , i_aa , aa ].reshape (1 )), 0 )
176+ (seq_log_probs , token_probs [i_aa , aa ].reshape (1 )), 0 )
174177 if i_s == 0 :
175178 log_probs = torch .sum (torch .Tensor (seq_log_probs )).reshape (1 )
176179 else :
@@ -179,124 +182,6 @@ def esm_unmasked_reconstruction_score(
179182 return log_probs
180183
181184
def esm_masked_pll(
    input_ids: torch.Tensor,        # (B, L)
    attention_mask: torch.Tensor,   # (B, L)
    model,
    mask_token_id: int,
    device: str | None = None,
    verbose: bool = False,
):
    """
    Compute true pseudo-log-likelihood (PLL) for an MLM (ESM).

    Each position is masked in turn, the model is run on the masked batch,
    and the log-probability of the true token at that position is summed
    per sequence. Padding positions contribute nothing (zeroed by the
    attention mask).

    Returns:
        pll_scores: torch.Tensor of shape (B,)
    """
    if device is None:
        device = next(model.parameters()).device

    ids = input_ids.to(device)
    mask = attention_mask.to(device)
    n_seqs, seq_len = ids.shape
    scores = torch.zeros(n_seqs, device=device)

    model.eval()

    positions = tqdm(
        range(seq_len),
        desc="ESM masked PLL",
        disable=not verbose,
    )
    for pos in positions:
        col_mask = mask[:, pos]
        # Position is padding for every sequence in the batch -> nothing to score.
        if col_mask.sum() == 0:
            continue

        # Mask column `pos` across the whole batch.
        masked_ids = ids.clone()
        masked_ids[:, pos] = mask_token_id

        with torch.no_grad():
            out = model(
                input_ids=masked_ids,
                attention_mask=mask,
            )

        # Log-probabilities over the vocabulary at the masked position (B, V),
        # then gather the log-prob of each sequence's true token there.
        pos_log_probs = F.log_softmax(out.logits[:, pos, :], dim=-1)
        true_tokens = ids[:, pos].unsqueeze(1)
        gathered = pos_log_probs.gather(dim=1, index=true_tokens).squeeze(1)

        # Zero out sequences for which `pos` is padding.
        scores = scores + gathered * col_mask

    return scores
245-
def esm_infer_masked_pll(
    xs,
    attention_mask,
    model,
    mask_token_id,
    batch_size: int = 4,
    device: str | None = None,
    verbose: bool = False,
):
    """
    Batched pseudo-log-likelihood inference over many sequences.

    Converts ``xs``/``attention_mask`` to long tensors if needed, broadcasts a
    1-D attention mask to all sequences, then scores the sequences in chunks
    of ``batch_size`` via :func:`esm_masked_pll`.

    Parameters
    ----------
    xs : sequence of token-id rows or torch.Tensor, shape (N, L)
    attention_mask : torch.Tensor or array-like, shape (L,) or (N, L)
    model : masked language model whose outputs expose ``.logits``
    mask_token_id : int, tokenizer id used to mask positions
    batch_size : int, sequences per forward batch
    device : str | None, resolved via ``get_device()`` when None
    verbose : bool, show a progress bar over batches

    Returns
    -------
    torch.Tensor of shape (N,) with one PLL score per sequence (on CPU).
    """
    if device is None:
        device = get_device()

    model = model.to(device)
    model.eval()

    if not isinstance(xs, torch.Tensor):
        xs = torch.tensor(xs, dtype=torch.long)

    if not isinstance(attention_mask, torch.Tensor):
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

    xs = xs.to(device)

    # Fix: guard empty input -- torch.cat([]) below would raise RuntimeError.
    if xs.shape[0] == 0:
        return torch.empty(0)

    # Expand a shared per-position mask to (N, L) if needed.
    if attention_mask.dim() == 1:
        attention_mask = attention_mask.unsqueeze(0).expand(xs.shape[0], -1)

    attention_mask = attention_mask.to(device)

    pll_all = []

    for i in tqdm(
        range(0, xs.shape[0], batch_size),
        desc="ESM PLL inference",
        disable=not verbose,
    ):
        xs_b = xs[i:i + batch_size]
        am_b = attention_mask[i:i + batch_size]

        # Inner call stays quiet; the outer bar already tracks progress.
        pll_b = esm_masked_pll(
            input_ids=xs_b,
            attention_mask=am_b,
            model=model,
            mask_token_id=mask_token_id,
            device=device,
            verbose=False,
        )

        pll_all.append(pll_b.cpu())

    return torch.cat(pll_all)
297-
298-
299-
300185def esm_mutation_only_mutation_masked_pll (
301186 tokenized_sequences : torch .Tensor , # (L,)
302187 wt_input_ids : torch .Tensor , # (L,)
@@ -306,6 +191,7 @@ def esm_mutation_only_mutation_masked_pll(
306191 train : bool = False ,
307192 device : str | None = None ,
308193 verbose : bool = False ,
194+ ** kwargs
309195):
310196 """
311197 Correct mutation-only pseudo-log-likelihood for ONE sequence.
@@ -335,16 +221,16 @@ def esm_mutation_only_mutation_masked_pll(
335221 masked_input_ids = tokenized_seq .clone ()
336222 masked_input_ids [pos ] = mask_token_id
337223 if train :
224+ outputs = model (
225+ input_ids = masked_input_ids .unsqueeze (0 ),
226+ attention_mask = attention_mask .unsqueeze (0 ),
227+ )
228+ else :
338229 with torch .no_grad ():
339230 outputs = model (
340231 input_ids = masked_input_ids .unsqueeze (0 ),
341232 attention_mask = attention_mask .unsqueeze (0 ),
342233 )
343- else :
344- outputs = model (
345- input_ids = masked_input_ids .unsqueeze (0 ),
346- attention_mask = attention_mask .unsqueeze (0 ),
347- )
348234 logits = outputs .logits # (1, L, V)
349235
350236 log_probs = F .log_softmax (logits [0 , pos ], dim = - 1 )
@@ -393,16 +279,16 @@ def esm_mutation_all_pos_masked_pll(
393279 masked_input_ids [pos ] = mask_token_id
394280
395281 if train :
282+ outputs = model (
283+ input_ids = masked_input_ids .unsqueeze (0 ),
284+ attention_mask = attention_mask .unsqueeze (0 ),
285+ )
286+ else :
396287 with torch .no_grad ():
397288 outputs = model (
398289 input_ids = masked_input_ids .unsqueeze (0 ),
399290 attention_mask = attention_mask .unsqueeze (0 ),
400291 )
401- else :
402- outputs = model (
403- input_ids = masked_input_ids .unsqueeze (0 ),
404- attention_mask = attention_mask .unsqueeze (0 ),
405- )
406292 logits = outputs .logits # (1, L, V)
407293
408294 log_probs = F .log_softmax (logits [0 , pos ], dim = - 1 )
@@ -437,13 +323,16 @@ def esm_infer_pll(
437323
438324 if not isinstance (attention_mask , torch .Tensor ):
439325 attention_mask = torch .tensor (attention_mask , dtype = torch .long )
440-
441- if inference_type == 'mutation_masking ' :
326+ wt_structure_input_ids = None
327+ if inference_type == 'mutation-masking ' :
442328 inference_function = esm_mutation_only_mutation_masked_pll
443- elif inference_type == 'full_masking' :
329+ elif inference_type in [ 'full-masking' , 'all-pos-masking' ] :
444330 inference_function = esm_mutation_all_pos_masked_pll
445- elif inference_type == 'unmasked' :
446- inference_function = esm_unmasked_reconstruction_score
331+ elif inference_type in ['unmasked' , 'wt-marginals' ]:
332+ inference_function = esm_unmasked_wt_score
333+ elif inference_type == 'prosst' :
334+ wt_input_ids , wt_structure_input_ids = wt_input_ids
335+ inference_function = esm_unmasked_wt_score
447336 else :
448337 raise SystemError ("Choose between 'mutation_masking', 'unmasked', and 'full_masking'" )
449338
@@ -462,6 +351,7 @@ def esm_infer_pll(
462351 pll = inference_function (
463352 tokenized_sequences = torch .tensor (xs_b [i ]),
464353 wt_input_ids = wt_input_ids ,
354+ structure_input_ids = wt_structure_input_ids ,
465355 attention_mask = attention_mask ,
466356 model = model ,
467357 mask_token_id = mask_token_id ,
0 commit comments