Skip to content

Commit 2158ac6

Browse files
committed
Getting GREMLIN sequence encodings in batches
of size 1000
1 parent 43ccba5 commit 2158ac6

File tree

4 files changed

+22
-7
lines changed

4 files changed

+22
-7
lines changed

pypef/dca/gremlin_inference.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,9 @@ def get_scores(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.
428428
seqs_int = self.seq2int(seqs)
429429

430430
try:
431-
if seqs_int.shape[-1] != len(v_idx):
432-
#logger.info(f'The input sequence length ({seqs_int.shape[-1]}) does not match the common gap-trimmed MSA sequence length ({len(v_idx)})!')
433-
seqs_int = seqs_int[..., v_idx]
434-
#logger.info(f'Updated shape: ({seqs_int.shape[-1]}) matches common MSA sequence length ({len(v_idx)}) now')
431+
if seqs_int.shape[-1] != len(v_idx): # The input sequence length ({seqs_int.shape[-1]})
432+
# does not match the common gap-trimmed MSA sequence length (len(v_idx)
433+
seqs_int = seqs_int[..., v_idx] # Shape matches common MSA sequence length (len(v_idx)) now
435434
except IndexError:
436435
raise SystemError(
437436
"The loaded GREMLIN parameter model does not match the input model "

pypef/llm/esm_lora_tune.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,25 @@ def corr_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
103103

104104

105105
def get_batches(a, dtype, batch_size=5,
                keep_numpy: bool = False, keep_remaining: bool = False,
                verbose: bool = False):
    """Split ``a`` into consecutive batches of ``batch_size`` elements.

    Parameters
    ----------
    a : array-like
        1-D or 2-D input sequence to batch.
    dtype :
        NumPy-compatible dtype for the intermediate array (and target
        dtype of the returned torch tensor).
    batch_size : int
        Number of elements per batch; a trailing remainder of
        ``len(a) % batch_size`` elements is dropped from the reshaped
        array (and optionally re-appended, see ``keep_remaining``).
    keep_numpy : bool
        If True, return the reshaped NumPy array instead of a torch tensor.
    keep_remaining : bool
        If True, return a list of batches with the dropped remainder
        appended as a final (shorter) batch, so no element is lost.
    verbose : bool
        If True, print the shape change and the number of dropped elements.

    Returns
    -------
    list of np.ndarray if ``keep_remaining``, np.ndarray if ``keep_numpy``,
    otherwise ``torch.Tensor``.
    """
    a = np.asarray(a, dtype=dtype)
    orig_shape = np.shape(a)
    remaining = len(a) % batch_size
    # Capture the tail BEFORE truncating: taking a[-remaining:] after the
    # truncation would lose the dropped elements and instead duplicate
    # items that are already part of the last full batch.
    a_remaining = a[len(a) - remaining:]
    if remaining != 0:
        a = a[:-remaining]
    if len(orig_shape) == 2:
        a = a.reshape(np.shape(a)[0] // batch_size, batch_size, np.shape(a)[1])
    else:  # elif len(orig_shape) == 1:
        a = a.reshape(np.shape(a)[0] // batch_size, batch_size)
    new_shape = np.shape(a)
    if verbose:
        print(f'{orig_shape} -> {new_shape} (dropped {remaining})')
    if keep_remaining:  # Returning a list
        batches = list(a)
        # Guard against remaining == 0: the original unconditionally
        # appended a_remaining, which raised NameError (unbound) when
        # len(a) divided evenly by batch_size.
        if remaining != 0:
            print('Adding dropped back to batches as last batch...')
            batches.append(a_remaining)
        return batches
    if keep_numpy:
        return a
    return torch.Tensor(a).to(dtype)

scripts/ProteinGym_runs/results/dca_esm_and_hybrid_opt_results.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,5 @@ No.,Dataset,N_Variants,N_Max_Muts,Untrained_Performance_DCA,Untrained_Performanc
155155
165,FECA_ECOLI_Tsuboyama_2023_2D1U,1886,2,0.40408514501427906,0.3139309947296232,0.5738858975054831,0.4523259703428921,0.5052927011009388,0.6656143272779307,0.4769453507671686,0.5682842582146705,0.7580245116989018,0.6098073159278943,0.7656754659582231,nan,1886,1786,1686,886,346
156156
166,GCN4_YEAST_Staller_2018,2638,44,0.25011546899669806,-0.006027620813706041,0.22764968209696385,0.24358636362901392,0.241174939429571,nan,0.2544759868327667,0.36879030431219906,nan,0.2859991050918842,0.5125215331300973,nan,2638,2538,2438,1638,684
157157
167,GFP_AEQVI_Sarkisyan_2016,51714,15,0.6406366653494072,0.1336688728267403,0.6860965519817945,0.6422492276272843,0.6495590297325259,nan,0.6486786880849276,0.6360629034990605,nan,0.7463828216993244,0.7711781697848346,nan,51714,51614,51514,50714,8310
158+
168,GRB2_HUMAN_Faure_2021,63366,2,0.5258434005381363,0.5367412810228084,0.7216670654700138,0.5839763004949682,0.6852622116666458,0.697364904206813,0.6746840807821936,0.7103753714289707,0.7316071269759475,0.708765178129165,0.7742172713240596,0.7938850307943259,63366,63266,63166,62366,13344
159+
169,HECD1_HUMAN_Tsuboyama_2023_3DKM,5586,2,0.28623326119763287,0.2150000228393089,0.2406307470397028,0.5991137006846494,0.6520916880669738,0.7123845116845012,0.671240119874008,0.6713008248409833,0.715769538708621,0.6959516975099861,0.7840773924211748,0.7706682813598381,5586,5486,5386,4586,986

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import gc
55
import time
66
import warnings
7+
import psutil
78
import json
9+
from tqdm import tqdm
810
import pandas as pd
911
import numpy as np
1012
import torch
@@ -60,6 +62,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
6062
get_vram()
6163
MAX_WT_SEQUENCE_LENGTH = 1000
6264
print(f"Maximum sequence length: {MAX_WT_SEQUENCE_LENGTH}")
65+
print(f"Loading LLM models into {device} device...")
6366
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
6467
prosst_vocab = prosst_tokenizer.get_vocab()
6568
prosst_base_model = prosst_base_model.to(device)
@@ -89,7 +92,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
8992
# print('Continuing (TODO: requires cut of PDB input structure residues)...')
9093
# continue
9194
# Getting % usage of virtual_memory (3rd field)
92-
import psutil;print(f'RAM used: {round(psutil.virtual_memory()[3]/1E9, 3)} '
95+
print(f'RAM used: {round(psutil.virtual_memory()[3]/1E9, 3)} '
9396
f'GB ({psutil.virtual_memory()[2]} %)')
9497
variant_fitness_data = pd.read_csv(csv_path, sep=',')
9598
print('N_variant-fitness-tuples:', np.shape(variant_fitness_data)[0])
@@ -142,7 +145,12 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
142145

143146
print('GREMLIN-DCA: optimization...')
144147
gremlin = GREMLIN(alignment=msa_path, opt_iter=100, optimize=True)
145-
x_dca = gremlin.collect_encoded_sequences(sequences)
148+
sequences_batched = get_batches(sequences, batch_size=1000,
149+
dtype=str, keep_remaining=True, verbose=True)
150+
x_dca = []
151+
for seq_b in tqdm(sequences_batched, desc="Getting GREMLIN sequence encodings"):
152+
for x in gremlin.collect_encoded_sequences(seq_b):
153+
x_dca.append(x)
146154
x_wt = gremlin.x_wt
147155
y_pred_dca = get_delta_e_statistical_model(x_dca, x_wt)
148156
print(f'DCA (unsupervised performance): {spearmanr(fitnesses, y_pred_dca)[0]:.3f}')

0 commit comments

Comments
 (0)