Skip to content

Commit 03b88f8

Browse files
committed
Update gremlin (dev): implement msa_start & msa_end
for trimming MSAs if needed
1 parent 40a7848 commit 03b88f8

File tree

5 files changed

+70
-19
lines changed

5 files changed

+70
-19
lines changed

pypef/dca/gremlin_inference.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
from tqdm import tqdm
5050
import torch
5151

52+
from pypef.llm.utils import get_batches
53+
from pypef.utils.variant_data import get_mismatches
54+
5255

5356
class GREMLIN:
5457
"""
@@ -65,6 +68,8 @@ def __init__(
6568
eff_cutoff=0.8,
6669
opt_iter=100,
6770
max_msa_seqs: int | None = 10000,
71+
msa_start: None | int = None,
72+
msa_end: None | int = None,
6873
seqs: list[str] | np.ndarray[str] | None =None,
6974
device: str | None = None
7075
):
@@ -98,12 +103,20 @@ def __init__(
98103
else:
99104
self.max_msa_seqs = max_msa_seqs
100105
self.states = len(self.char_alphabet)
106+
self.msa_start = msa_start
107+
if msa_end == 0:
108+
msa_end = None
109+
self.msa_end = msa_end
101110
logger.info('Loading MSA...')
102111
if seqs is None:
103112
self.seqs, self.seq_ids = self.get_sequences_from_msa(alignment)
104113
else:
105114
self.seqs = seqs
106115
self.seq_ids = np.array([n for n in range(len(self.seqs))])
116+
self.first_msa_seq = self.seqs[0]
117+
if self.msa_start is not None or self.msa_end is not None:
118+
logger.info(f'Trimmed sequence length.. first sequence is printed here as '
119+
f'example (Length: {len(self.first_msa_seq)}): {self.first_msa_seq}')
107120
logger.info(f'Found {len(self.seqs)} sequences in the MSA...')
108121
self.msa_ori = self.get_msa_ori()
109122
logger.info(f'MSA shape: {np.shape(self.msa_ori)}')
@@ -153,7 +166,14 @@ def get_sequences_from_msa(self, msa_file: str):
153166
with open(msa_file, 'r') as fh:
154167
alignment = AlignIO.read(fh, "fasta")
155168
for record in alignment:
156-
sequences.append(str(record.seq))
169+
seq = str(record.seq)
170+
if self.msa_start is not None and self.msa_end is not None:
171+
seq = seq[self.msa_start:self.msa_end]
172+
elif self.msa_start is not None:
173+
seq = seq[self.msa_start:]
174+
elif self.msa_end is not None:
175+
seq = seq[:self.msa_end]
176+
sequences.append(seq)
157177
seq_ids.append(str(record.id))
158178
assert len(sequences) == len(seq_ids), f"{len(sequences)}, {len(seq_ids)}"
159179
return np.array(sequences), np.array(seq_ids)
@@ -353,14 +373,14 @@ def run_optimization(self):
353373

354374
self.mt_v, self.vt_v = torch.zeros_like(self.v), torch.zeros_like(self.v)
355375
self.mt_w, self.vt_w = torch.zeros_like(self.w), torch.zeros_like(self.w)
356-
logger.info(f'Initial loss: {self._loss()}')
376+
logger.info(f'Initial loss: {self._loss():.5f}')
357377
for i in range(self.opt_iter):
358378
self.opt_adam_step()
359379
try:
360380
if (i + 1) % int(self.opt_iter / 10) == 0:
361-
logger.info(f'Loss step {i + 1}: {self._loss()}')
381+
logger.info(f'Loss step {i + 1}: {self._loss():.5f}')
362382
except ZeroDivisionError:
363-
logger.info(f'Loss step {i + 1}: {self._loss()}')
383+
logger.info(f'Loss step {i + 1}: {self._loss():.5f}')
364384

365385
self.v = self.v.detach().cpu().numpy()
366386
self.w = self.w.detach().cpu().numpy()
@@ -416,7 +436,20 @@ def get_scores(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.
416436
if v_idx is None:
417437
v_idx = self.v_idx
418438
seqs_int = self.seq2int(seqs)
419-
439+
wt_seq_len = len(self.wt_seq)
440+
#if np.shape(seqs_int)[1] != wt_seq_len:
441+
# raise RuntimeError(
442+
# f"Input sequence shape (length: {np.shape(seqs_int)[1]}) does not match GREMLIN "
443+
# f"MSA shape (common sequence length: {wt_seq_len}) inferred from the MSA."
444+
# )
445+
# Check nums of mutations to MSA first/WT sequence and gives warning if too apart from MSA seq
446+
for i, seq in enumerate(seqs):
447+
n_mismatches, mismatches = get_mismatches(self.wt_seq, seq)
448+
if n_mismatches / wt_seq_len > 0.05:
449+
logger.warning(
450+
f"Sequence {mismatches} contains more than 5% sequence mismatches to the "
451+
f"first MSA/\"WT\" sequence. Effect predictions will likely be incorrect!"
452+
)
420453
try:
421454
if seqs_int.shape[-1] != len(v_idx): # The input sequence length ({seqs_int.shape[-1]})
422455
# does not match the common gap-trimmed MSA sequence length (len(v_idx)
@@ -471,8 +504,16 @@ def collect_encoded_sequences(self, seqs, v=None, w=None, v_idx=None):
471504
Wrapper function for encoding input sequences using the self.get_scores
472505
function with encode set to True.
473506
"""
474-
xs = self.get_scores(seqs, v, w, v_idx, encode=True)
475-
return xs
507+
xs = []
508+
sequences_batched = get_batches(
509+
seqs, batch_size=1000, dtype=str,
510+
keep_remaining=True, verbose=True
511+
)
512+
sequences_batched = np.atleast_2d(sequences_batched)
513+
514+
for seq_batch in sequences_batched:
515+
xs.append(self.get_scores(seq_batch, v, w, v_idx, encode=True))
516+
return xs[0]
476517

477518
@staticmethod
478519
def normalize(apc_mat):
@@ -691,7 +732,7 @@ def save_corr_csv(gremlin: GREMLIN, min_distance: int = 0, sort_by: str = 'apc')
691732
)
692733
df_mtx_sorted_mindist.to_csv(f"coevolution_{sort_by}_sorted.csv", sep=',')
693734
logger.info(f"Saved coevolution CSV data as "
694-
f"{os.path.abspath(f'coevolution_{sort_by}_sorted.csv')}")
735+
f"{os.path.abspath(f'coevolution_{sort_by}_sorted.csv')}")
695736

696737

697738
def plot_predicted_ssm(gremlin: GREMLIN):

pypef/hybrid/hybrid_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def train_llm(self):
447447
self.input_ids,
448448
self.llm_attention_mask,
449449
self.structure_input_ids,
450-
n_epochs=50,
450+
n_epochs=50,
451451
device=self.device,
452452
verbose=self.verbose
453453
)
@@ -641,7 +641,6 @@ def ls_ts_performance(self):
641641
return spearman_r, reg, beta_1, beta_2
642642

643643

644-
645644
"""
646645
###########################################################################################
647646
# Below: Some helper functions that call or are dependent on the DCALLMHybridModel class. #

pypef/llm/prosst_lora_tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def load_model(model, filename):
128128
def prosst_train(
129129
x_sequence_batches, score_batches, loss_fn, model, optimizer,
130130
input_ids, attention_mask, structure_input_ids,
131-
n_epochs=3, device: str | None = None, seed: int | None = None,
131+
n_epochs=50, device: str | None = None, seed: int | None = None,
132132
early_stop: int = 50, verbose: bool = True):
133133
if seed is not None:
134134
torch.manual_seed(seed)
@@ -139,7 +139,7 @@ def prosst_train(
139139
x_sequence_batches = x_sequence_batches.to(device)
140140
score_batches = score_batches.to(device)
141141
pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
142-
epoch_spearman_1 = 0.0
142+
epoch_spearman_1 = -1.0
143143
did_not_improve_counter = 0
144144
best_model = None
145145
best_model_epoch = np.nan
@@ -177,7 +177,7 @@ def prosst_train(
177177
f"Y_true: {score_batches.cpu().numpy().flatten()}, "
178178
f"Y_pred: {np.array(y_preds_detached)}"
179179
)
180-
if epoch_spearman_2 > epoch_spearman_1:
180+
if epoch_spearman_2 > epoch_spearman_1 or epoch == 0:
181181
if best_model is not None:
182182
if os.path.isfile(best_model):
183183
os.remove(best_model)

pypef/llm/utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ def get_batches(a, dtype, batch_size=5,
3535
a_remaining = a[-remaining:]
3636
else:
3737
logger.info(f"Batch size greater than or equal to total array length: "
38-
f"returning full array (of shape: {np.shape(a)})...")
38+
f"returning full array (of shape: {np.shape(a)})...")
3939
if keep_remaining:
40-
return list(a)
41-
else:
4240
return a
4341
if len(orig_shape) == 2:
4442
a = a.reshape(np.shape(a)[0] // batch_size, batch_size, np.shape(a)[1])
@@ -47,10 +45,9 @@ def get_batches(a, dtype, batch_size=5,
4745
new_shape = np.shape(a)
4846
if verbose:
4947
logger.info(f'{orig_shape} -> {new_shape} (dropped {remaining})')
50-
if keep_remaining: # Returning a list
51-
a = list(a)
48+
if keep_remaining:
5249
logger.info('Adding dropped back to batches as last batch...')
53-
a.append(a_remaining)
50+
a = np.append(a, a_remaining)
5451
return a
5552
if keep_numpy:
5653
return a

pypef/utils/variant_data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import numpy as np
77
import pandas as pd
8+
import warnings
89

910
import logging
1011
logger = logging.getLogger('pypef.utils.variant_data')
@@ -462,3 +463,16 @@ def read_csv_and_shift_pos_ints(
462463
data = np.array([new_col, column_2]).T
463464
new_df = pd.DataFrame(data, columns=['variant', 'fitness'])
464465
new_df.to_csv(infile[:-4] + '_new' + infile[-4:], sep=';', index=False)
466+
467+
468+
def get_mismatches(seq_a: str, seq_b: str) -> tuple[int, str]:
    """Count position-wise mismatches between two equal-length sequences.

    Each mismatch is encoded as ``{aa_a}{pos}{aa_b}`` with a 1-based
    position, e.g. comparing "AC" with "AG" yields ``(1, "C2G")``.

    Parameters
    ----------
    seq_a : str
        Reference sequence (e.g. the first MSA/"WT" sequence).
    seq_b : str
        Sequence to compare against ``seq_a``; must have the same length.

    Returns
    -------
    tuple[int, str]
        Number of mismatching positions and a comma-joined mismatch string
        (empty string if the sequences are identical).

    Raises
    ------
    RuntimeError
        If the two sequences differ in length.
    """
    if len(seq_a) != len(seq_b):
        # Mismatched lengths make position-wise comparison meaningless.
        logger.warning("Sequence lengths do not match!")
        raise RuntimeError(f"{seq_a}\n{len(seq_a)}\n{seq_b}\n{len(seq_b)}")
    # Collect mismatches in a list and join once (avoids quadratic str +=).
    mismatches = [
        f"{aa}{pos}{seq_b[pos - 1]}"
        for pos, aa in enumerate(seq_a, start=1)
        if aa != seq_b[pos - 1]
    ]
    return len(mismatches), ",".join(mismatches)

0 commit comments

Comments
 (0)