|
18 | 18 |
|
19 | 19 | import logging |
20 | 20 |
|
21 | | -from pypef.plm.prosst_lora_tune import get_logits_from_full_seqs |
22 | 21 | logger = logging.getLogger('pypef.llm.esm_lora_tune') |
23 | 22 |
|
24 | 23 | import torch |
25 | | -import torch.nn.functional as F |
26 | 24 | import numpy as np |
27 | 25 | from scipy.stats import spearmanr |
28 | 26 | from tqdm import tqdm |
29 | 27 |
|
30 | | - |
31 | 28 | from peft import LoraConfig, get_peft_model |
32 | 29 | from transformers import logging as hf_logging |
33 | 30 | hf_logging.set_verbosity_error() |
@@ -143,314 +140,6 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=Fals |
143 | 140 | return torch.flatten(y_preds_total) |
144 | 141 |
|
145 | 142 |
|
def unmasked_wt_score(
        tokenized_sequences,
        attention_mask,
        wt_input_ids,
        model,
        train: bool = False,
        cut_special_tokens: bool = True,  # assumption: cut first and last token
        device=None,
        **kwargs
        ):
    """Score candidate sequences under a single wild-type forward pass.

    Runs the model once on ``wt_input_ids`` (wt-marginals style) and, for each
    candidate sequence, sums the log-probabilities of its tokens under the
    wild-type output distribution.

    Parameters
    ----------
    tokenized_sequences : (B, L) tensor of candidate token ids.
    attention_mask : scalar or (L,) mask value(s), broadcast to the WT shape.
    wt_input_ids : (L,) or (1, L) wild-type token ids.
    model : callable returning an object with a ``.logits`` attribute.
    train : keep gradients when True; otherwise run under ``torch.no_grad()``.
    cut_special_tokens : drop first and last token (assumed CLS/EOS).
    device : target device; resolved via ``get_device()`` when None.
    **kwargs : may carry ``structure_input_ids`` for structure-aware models.

    Returns
    -------
    (B,) tensor of per-sequence summed log-probabilities (float64 accumulation).
    """
    if device is None:
        device = get_device()
    if wt_input_ids.dim() == 1:
        wt_input_ids = wt_input_ids.unsqueeze(0)
    structure_input_ids = kwargs.get("structure_input_ids", None)

    # Broadcast the (possibly scalar) attention mask to the WT input shape.
    attention_masks = torch.as_tensor(
        np.full(np.shape(wt_input_ids), attention_mask)).to(torch.int64)

    def _forward():
        # Single wild-type forward pass; structure tokens passed through when given.
        if structure_input_ids is not None:
            return model(
                input_ids=wt_input_ids.to(device),
                attention_mask=attention_masks.to(device),
                ss_input_ids=structure_input_ids.to(device)
            )
        return model(
            wt_input_ids.to(device),
            attention_masks.to(device),
            output_hidden_states=False
        )

    if train:
        outputs = _forward()
    else:
        with torch.no_grad():
            outputs = _forward()

    logits = outputs.logits.squeeze(0)  # remove batch dim
    # Better make sure that special tokens are always removed / masked
    # and only pure amino acid sequence tokens are present / unmasked
    tokenized_seq_len = tokenized_sequences.shape[1]
    if cut_special_tokens:
        logits = logits[1:-1]  # drop CLS/EOS
        tokenized_seq_len -= 2
    token_probs = torch.log_softmax(logits, dim=-1)
    assert tokenized_seq_len == token_probs.shape[0], (
        f"{tokenized_seq_len} != {token_probs.shape[0]}")

    def _seq_logprob(seq):
        # Pick each token's log-probability under the WT distribution and sum.
        ids = seq[1:-1] if cut_special_tokens else seq
        positions = torch.arange(ids.shape[0], device=ids.device)
        return token_probs[positions, ids].sum(dtype=torch.float64)

    return torch.stack([_seq_logprob(seq) for seq in tokenized_sequences])
218 | | - |
219 | | - |
def esm_mutation_only_mutation_masked_pll(
    tokenized_sequences: torch.Tensor,  # (B, L) candidate token ids
    wt_input_ids: torch.Tensor,  # (L,) or (1, L) wild-type token ids
    attention_mask: torch.Tensor,  # (L,) or (1, L)
    model,
    mask_token_id: int,
    train: bool = False,
    device: str | None = None,
    verbose: bool = False,
    **kwargs
):
    """
    Mutation-only masked pseudo-log-likelihood for sequences.

    For each sequence, every position that differs from the wild type
    (excluding padding and the first/last token, assumed CLS/EOS) is masked
    one at a time; the model's log-probability of the true mutant token at
    the masked position is accumulated into that sequence's score.

    Returns a (B,) tensor of per-sequence PLL scores.
    """
    if device is None:
        # Resolve device like the sibling scoring functions do.
        device = get_device()
    tokenized_sequences = tokenized_sequences.to(device)
    structure_input_ids = kwargs.get("structure_input_ids", None)
    if structure_input_ids is not None:
        assert structure_input_ids.shape[1] == tokenized_sequences.shape[1], (
            f"{structure_input_ids.shape[1]} != {tokenized_sequences.shape[1]}")
        structure_input_ids = structure_input_ids.to(device)
    if wt_input_ids.dim() == 2 and wt_input_ids.shape[0] == 1:
        wt_input_ids = wt_input_ids.squeeze(0)
    wt_input_ids = wt_input_ids.to(device)
    if attention_mask.dim() == 2 and attention_mask.shape[0] == 1:
        attention_mask = attention_mask.squeeze(0)
    attention_mask = attention_mask.to(device)
    plls = torch.empty(len(tokenized_sequences), device=device)
    for i, tokenized_seq in enumerate(tokenized_sequences):
        assert tokenized_seq.dim() == 1
        assert wt_input_ids.dim() == 1
        assert attention_mask.dim() == 1
        assert tokenized_seq.shape == wt_input_ids.shape == attention_mask.shape
        pll = torch.tensor(0.0, device=device)

        # Identify mutated positions (exclude padding, CLS, EOS)
        diff = (tokenized_seq != wt_input_ids) & (attention_mask == 1)
        diff[0] = False
        diff[-1] = False

        mutated_positions = diff.nonzero(as_tuple=False).flatten()

        for pos in tqdm(
            mutated_positions,
            desc="Masked PLL (single sequence)",
            disable=not verbose
        ):
            masked_input_ids = tokenized_seq.clone()
            masked_input_ids[pos] = mask_token_id
            if structure_input_ids is not None:
                # NOTE(review): assumes structure ids are (1, L) and share the
                # sequence mask token id — confirm for the structure vocabulary.
                masked_ss_input_ids = structure_input_ids.clone()
                masked_ss_input_ids[0, pos] = mask_token_id

            if train:
                if structure_input_ids is not None:
                    outputs = model(
                        input_ids=masked_input_ids.unsqueeze(0),
                        attention_mask=attention_mask.unsqueeze(0),
                        ss_input_ids=masked_ss_input_ids
                    )
                else:
                    outputs = model(
                        input_ids=masked_input_ids.unsqueeze(0),
                        attention_mask=attention_mask.unsqueeze(0),
                        output_hidden_states=False
                    )
            else:
                with torch.no_grad():
                    if structure_input_ids is not None:
                        outputs = model(
                            input_ids=masked_input_ids.unsqueeze(0),
                            attention_mask=attention_mask.unsqueeze(0),
                            ss_input_ids=masked_ss_input_ids
                        )
                    else:
                        outputs = model(
                            input_ids=masked_input_ids.unsqueeze(0),
                            attention_mask=attention_mask.unsqueeze(0),
                            output_hidden_states=False
                        )
            logits = outputs.logits  # (1, L, V)

            # torch.log_softmax instead of F.log_softmax: this module does not
            # import torch.nn.functional, and siblings already use torch.log_softmax.
            log_probs = torch.log_softmax(logits[0, pos], dim=-1)
            true_token = tokenized_seq[pos]

            pll = pll + log_probs[true_token]

        plls[i] = pll

    return plls
311 | | - |
312 | | - |
def esm_mutation_all_pos_masked_pll(
    tokenized_sequences: torch.Tensor,  # (B, L) candidate token ids
    attention_mask: torch.Tensor,  # (L,) or (1, L)
    model,
    mask_token_id: int,
    train: bool = False,
    device: str | None = None,
    verbose: bool = False,
    **kwargs
):
    """
    All-positions masked pseudo-log-likelihood for sequences.

    For each sequence, every real token position (attention mask == 1,
    excluding the first/last token, assumed CLS/EOS) is masked one at a
    time; the model's log-probability of the true token at the masked
    position is accumulated into that sequence's score.

    Returns a (B,) tensor of per-sequence PLL scores.
    """
    if device is None:
        # Resolve device like the sibling scoring functions do.
        device = get_device()
    structure_input_ids = kwargs.get("structure_input_ids", None)
    if structure_input_ids is not None:
        assert structure_input_ids.shape[1] == tokenized_sequences.shape[1], (
            f"{structure_input_ids.shape[1]} != {tokenized_sequences.shape[1]}")
        structure_input_ids = structure_input_ids.to(device)
    tokenized_sequences = tokenized_sequences.to(device)
    if attention_mask.dim() == 2 and attention_mask.shape[0] == 1:
        attention_mask = attention_mask.squeeze(0)
    attention_mask = attention_mask.to(device)
    plls = torch.empty(len(tokenized_sequences), device=device)
    for i, tokenized_seq in enumerate(tokenized_sequences):
        seq_len = tokenized_seq.shape[0]
        pll = torch.tensor(0.0, device=device)

        # Positions to score: all real tokens except CLS/EOS
        positions = (attention_mask == 1).nonzero(as_tuple=False).flatten()
        positions = positions[(positions != 0) & (positions != seq_len - 1)]

        for pos in tqdm(
            positions,
            desc="Masked PLL (single sequence)",
            disable=not verbose
        ):
            masked_input_ids = tokenized_seq.clone()
            masked_input_ids[pos] = mask_token_id

            if structure_input_ids is not None:
                # NOTE(review): assumes structure ids are (1, L) and share the
                # sequence mask token id — confirm for the structure vocabulary.
                masked_ss_input_ids = structure_input_ids.clone()
                masked_ss_input_ids[0, pos] = mask_token_id

            if train:
                if structure_input_ids is not None:
                    outputs = model(
                        input_ids=masked_input_ids.unsqueeze(0),
                        attention_mask=attention_mask.unsqueeze(0),
                        ss_input_ids=masked_ss_input_ids
                    )
                else:
                    outputs = model(
                        input_ids=masked_input_ids.unsqueeze(0),
                        attention_mask=attention_mask.unsqueeze(0),
                        output_hidden_states=False
                    )
            else:
                with torch.no_grad():
                    if structure_input_ids is not None:
                        outputs = model(
                            input_ids=masked_input_ids.unsqueeze(0),
                            attention_mask=attention_mask.unsqueeze(0),
                            ss_input_ids=masked_ss_input_ids
                        )
                    else:
                        outputs = model(
                            input_ids=masked_input_ids.unsqueeze(0),
                            attention_mask=attention_mask.unsqueeze(0),
                            output_hidden_states=False
                        )
            logits = outputs.logits  # (1, L, V)

            # torch.log_softmax instead of F.log_softmax: this module does not
            # import torch.nn.functional, and siblings already use torch.log_softmax.
            log_probs = torch.log_softmax(logits[0, pos], dim=-1)
            true_token = tokenized_seq[pos]
            pll = pll + log_probs[true_token]

        plls[i] = pll

    return plls
393 | | - |
394 | | - |
def plm_inference(
    xs,
    wt_input_ids,
    attention_mask,
    model,
    mask_token_id,
    inference_type='unmasked',
    wt_structure_input_ids=None,
    batch_size=5,
    train=False,
    device=None,
    verbose=False,
):
    """
    Batched protein-LM scoring dispatcher.

    Selects a scoring function by ``inference_type`` and applies it to ``xs``
    in batches of ``batch_size``:

    - 'mutation-masking'                 -> esm_mutation_only_mutation_masked_pll
    - 'full-masking' / 'all-pos-masking' -> esm_mutation_all_pos_masked_pll
    - 'unmasked' / 'wt-marginals'        -> unmasked_wt_score

    Returns a 1-D tensor of per-sequence scores, concatenated over batches.

    Raises
    ------
    SystemError
        For an unknown ``inference_type``.  NOTE(review): ValueError would be
        conventional; kept as SystemError so existing callers' handlers work.
    """
    if device is None:
        device = get_device()

    model = model.to(device)

    if not isinstance(xs, torch.Tensor):
        xs = torch.tensor(xs, dtype=torch.long)

    if not isinstance(attention_mask, torch.Tensor):
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    if inference_type == 'mutation-masking':
        inference_function = esm_mutation_only_mutation_masked_pll
    elif inference_type in ['full-masking', 'all-pos-masking']:
        inference_function = esm_mutation_all_pos_masked_pll
    elif inference_type in ['unmasked', 'wt-marginals']:
        inference_function = unmasked_wt_score
    else:
        # List every accepted value (the old message omitted the aliases).
        raise SystemError(
            "Choose between 'mutation-masking', "
            "'full-masking'/'all-pos-masking', and 'unmasked'/'wt-marginals'"
        )

    scores = []

    xs_b = get_batches(xs, dtype=int, batch_size=batch_size, keep_remaining=True, verbose=True)
    # NOTE(review): assumes `device` is a string ('cpu'/'cuda') — a
    # torch.device object has no .upper(); confirm callers.
    # (Stray trailing quote removed from the progress-bar description.)
    desc = f"Inference: {inference_type} batch (size={batch_size}) processing ({device.upper()})"

    pbar = tqdm(
        range(len(xs_b)),
        desc=desc,
        disable=not verbose
    )

    for i in pbar:
        pll = inference_function(
            # as_tensor avoids a redundant copy (and a UserWarning) when the
            # batch is already a tensor — `xs` was converted to one above.
            tokenized_sequences=torch.as_tensor(xs_b[i]),
            wt_input_ids=wt_input_ids,
            structure_input_ids=wt_structure_input_ids,
            attention_mask=attention_mask,
            model=model,
            mask_token_id=mask_token_id,
            train=train,
            device=device,
            verbose=False
        )
        scores.append(pll)
    return torch.cat(scores)
452 | | - |
453 | | - |
454 | 143 | def esm_train( |
455 | 144 | xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3, |
456 | 145 | device: str | None = None, seed: int | None = None, |
|
0 commit comments