Skip to content

Commit 4c5986a

Browse files
committed
Add extract_emb option
1 parent a88e69b commit 4c5986a

File tree

4 files changed

+144
-271
lines changed

4 files changed

+144
-271
lines changed

pypef/gaussian_process/gp_prosst_test.py

Lines changed: 64 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
2+
3+
"""
4+
Gaussian process optimization similar to (but less sophisticated than)
5+
Kermut: Composite kernel regression for protein variant effects
6+
Peter Mørch Groth, Mads Herbert Kerrn, Lars Olsen, Jesper Salomon, Wouter Boomsma
7+
2024, 38th Conference on Neural Information Processing Systems (NeurIPS 2024).
8+
TL;DR: Gaussian process regression model with a novel composite kernel, Kermut, achieves
9+
state-of-the-art variant effect prediction while providing meaningful uncertainties.
10+
https://openreview.net/forum?id=jM9atrvUii
11+
"""
12+
13+
114
import pandas as pd
215
import numpy as np
316
from sklearn.model_selection import train_test_split
@@ -10,12 +23,11 @@
1023

1124
from tqdm import tqdm
1225

13-
from pypef.llm.prosst_lora_tune import (
14-
get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
15-
prosst_tokenize_sequences, prosst_train
26+
from pypef.plm.prosst_lora_tune import (
27+
get_prosst_models, get_structure_quantizied
1628
)
17-
from pypef.llm.inference import inference
18-
from pypef.utils.helpers import get_vram, get_device
29+
from pypef.plm.inference import tokenize_sequences, plm_inference
30+
from pypef.utils.helpers import get_device
1931

2032

2133
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -76,14 +88,16 @@ def extract_prosst_embeddings(
7688
structures = [structures] * len(sequences)
7789

7890
assert len(sequences) == len(structures), \
79-
"Number of sequences must match number of structures"
91+
(f"Number of sequences must match number of structures "
92+
f"{len(sequences)} != {len(structures)}")
8093

8194
for seq, struct in tqdm(zip(sequences, structures),
8295
total=len(sequences),
8396
desc="Embedding (ProSST)"):
8497
# Tokenize sequence
8598
tokenized = prosst_tokenizer(
8699
[seq],
100+
max_length=len(seq) + 2,
87101
return_tensors="pt",
88102
padding=False,
89103
truncation=False
@@ -124,29 +138,32 @@ def extract_prosst_embeddings(
124138
return X
125139

126140

141+
def gp_fit():
142+
pass
127143

128144

129145

146+
def gp_fit_sklearn():
147+
pass
148+
149+
130150

131151

132152
if __name__ == '__main__':
133-
wt_seq = list(read_fasta_biopython('example_data/blat_ecolx/blat_ecolx_wt_seq.fa').values())[0]
134-
pdb = 'example_data/blat_ecolx/BLAT_ECOLX.pdb'
153+
wt_seq = list(read_fasta_biopython('datasets/BLAT_ECOLX/blat_ecolx_wt.fasta').values())[0]
154+
pdb = 'datasets/BLAT_ECOLX/BLAT_ECOLX.pdb'
135155
device = get_device()
136156
print("Getting ProSST models")
137157
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
138158
prosst_vocab = prosst_tokenizer.get_vocab()
139159
prosst_base_model = prosst_base_model.to(device)
140160

141161
print(f"Getting structure tokens...")
142-
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
162+
wt_input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
143163
pdb, prosst_tokenizer, wt_seq, verbose=True
144164
)
145165

146-
147-
148-
149-
df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')
166+
df = pd.read_csv('datasets/BLAT_ECOLX/BLAT_ECOLX_Stiffler_2015.csv')
150167
sequences = df['mutated_sequence'].to_list()
151168
y = df['DMS_score'].to_list()
152169

@@ -156,10 +173,18 @@ def extract_prosst_embeddings(
156173
# --- Step 2: Extract ProSST embeddings ---
157174
print(structure_input_ids)
158175
print('np.shape(structure_input_ids):', np.shape(structure_input_ids))
159-
wt_structure_input_ids = structure_input_ids[0, 1:-1].tolist() # Remove CLS/EOS
160-
X_train = extract_prosst_embeddings(prosst_base_model, prosst_tokenizer, s_train, wt_structure_input_ids)
176+
wt_structure_input_ids = structure_input_ids
177+
178+
X_emb_train = extract_prosst_embeddings(prosst_base_model, prosst_tokenizer, s_train, wt_structure_input_ids[0, 1:-1].tolist()) # Remove CLS/EOS
179+
180+
x_train_2, prosst_attention_mask_2 = tokenize_sequences(s_train, prosst_tokenizer)
181+
assert len(prosst_attention_mask[0]) == len(prosst_attention_mask_2), f"{len(prosst_attention_mask[0])}\n !=\n {len(prosst_attention_mask_2)}"
182+
183+
X_emb_train_2 = plm_inference(x_train_2, wt_input_ids, prosst_attention_mask, prosst_base_model,
184+
extract_emb=True, wt_structure_input_ids=wt_structure_input_ids) #[0, 1:-1])
185+
161186
print("Embedding extraction done")
162-
print(np.shape(X_train))
187+
assert np.shape(X_emb_train) == np.shape(X_emb_train_2)
163188

164189
# --- Step 3: Fit Gaussian Process ---
165190
if USE_SCIKIT_LEARN:
@@ -168,12 +193,12 @@ def extract_prosst_embeddings(
168193

169194
kernel = 1.0 * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.1)
170195
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, normalize_y=True)
171-
gpr.fit(X_train, y_train)
196+
gpr.fit(X_emb_train, y_train)
172197

173198
else: # GPYTORCH
174199
import gpytorch
175200

176-
X_train_t = torch.tensor(X_train, dtype=torch.float32).to(device)
201+
X_train_t = torch.tensor(X_emb_train, dtype=torch.float32).to(device)
177202
y_train_t = torch.tensor(y_train, dtype=torch.float32).to(device)
178203

179204
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
@@ -195,26 +220,31 @@ def extract_prosst_embeddings(
195220

196221
# --- Step 4: Extract ProSST embeddings for test sequences ---
197222
print("Extracting ProSST embeddings for test sequences...")
198-
X_test = extract_prosst_embeddings(
223+
X_emb_test = extract_prosst_embeddings(
199224
model=prosst_base_model,
200225
prosst_tokenizer=prosst_tokenizer,
201226
sequences=s_test,
202-
structures=wt_structure_input_ids, # still using same WT structure
227+
structures=wt_structure_input_ids[0, 1:-1].tolist(), # still using same WT structure
203228
device=device
204229
)
205-
print("Test embeddings shape:", X_test.shape)
230+
print("Test embeddings shape:", X_emb_test.shape)
231+
232+
x_test_2, prosst_attention_mask_2 = tokenize_sequences(s_test, prosst_tokenizer)
233+
X_emb_test_2 = plm_inference(x_test_2, wt_input_ids, prosst_attention_mask, prosst_base_model,
234+
extract_emb=True, wt_structure_input_ids=wt_structure_input_ids) #[0, 1:-1])
206235

207236
# --- Step 5: Predict with Gaussian Process ---
208-
if USE_SCIKIT_LEARN:
209-
y_mean, y_std = gpr.predict(X_test, return_std=True)
210-
else:
211-
X_test_t = torch.tensor(X_test, dtype=torch.float32).to(device)
212-
gp_model.eval()
213-
likelihood.eval()
214-
with torch.no_grad(), gpytorch.settings.fast_pred_var():
215-
pred = likelihood(gp_model(X_test_t))
216-
y_mean = pred.mean.cpu().numpy()
217-
lower, upper = pred.confidence_region() # optional 95% CI
218-
219-
print("Predicted fitness:", y_mean)
220-
print("Spearman correlation:", spearmanr(y_test, y_mean))
237+
for x_t in [torch.tensor(X_emb_test).to(device), X_emb_test_2]:
238+
if USE_SCIKIT_LEARN:
239+
y_mean, y_std = gpr.predict(x_t, return_std=True)
240+
else:
241+
#X_test_t = torch.tensor(x_t, dtype=torch.float32).to(device)
242+
gp_model.eval()
243+
likelihood.eval()
244+
with torch.no_grad(), gpytorch.settings.fast_pred_var():
245+
pred = likelihood(gp_model(x_t))
246+
y_mean = pred.mean.cpu().numpy()
247+
lower, upper = pred.confidence_region() # optional 95% CI
248+
249+
print("Predicted fitness:", y_mean)
250+
print("Spearman correlation:", spearmanr(y_test, y_mean)) # SignificanceResult(statistic=0.8617991109937434, pvalue=0.0)

pypef/hybrid/hybrid_model.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
4040
from pypef.plm.esm_lora_tune import get_esm_models
4141
from pypef.plm.prosst_lora_tune import get_prosst_models
42-
from pypef.plm.inference import esm_setup, prosst_setup, llm_tokenizer, inference
42+
from pypef.plm.inference import esm_setup, prosst_setup, tokenize_sequences, plm_inference
4343
from pypef.plm.utils import get_batches
4444

4545
# sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
@@ -90,6 +90,7 @@ def __init__(
9090
self.x_train_llm = llm_model_input['esm1v']['x_llm']
9191
self.wt_input_ids = llm_model_input['esm1v']['wt_input_ids']
9292
self.llm_attention_mask = llm_model_input['esm1v']['llm_attention_mask']
93+
self.llm_tokenizer = llm_model_input['esm1v']['llm_tokenizer']
9394
elif len(list(llm_model_input.keys())) == 1 and list(llm_model_input.keys())[0] == 'prosst':
9495
self.llm_key = 'prosst'
9596
self.llm_base_model = llm_model_input['prosst']['llm_base_model']
@@ -102,6 +103,7 @@ def __init__(
102103
self.llm_attention_mask = llm_model_input['prosst']['llm_attention_mask']
103104
self.wt_input_ids = llm_model_input['prosst']['wt_input_ids']
104105
self.structure_input_ids = llm_model_input['prosst']['structure_input_ids']
106+
self.llm_tokenizer = llm_model_input['prosst']['llm_tokenizer']
105107
else:
106108
raise RuntimeError("LLM input model dictionary not supported. Currently supported "
107109
"models are 'esm1v' or 'prosst'")
@@ -661,7 +663,7 @@ def hybrid_prediction(
661663
verbose=verbose,
662664
device=self.device).detach().cpu().numpy()
663665
elif self.llm_key == 'esm1v':
664-
x_llm_b = torch.from_numpy(get_batches(x_llm, batch_size=1, dtype=int))
666+
#x_llm_b = torch.from_numpy(get_batches(x_llm, batch_size=1, dtype=int))
665667
y_llm = self.llm_inference_function(
666668
xs=x_llm,
667669
wt_input_ids=self.wt_input_ids,
@@ -1062,11 +1064,11 @@ def performance_ls_ts(
10621064
if llm is not None:
10631065
if llm.lower().startswith('esm'):
10641066
llm_dict = esm_setup(train_sequences)
1065-
x_llm_test = llm_tokenizer(llm_dict, test_sequences)
1067+
x_llm_test = tokenize_sequences(test_sequences, llm_dict['esm1v']['llm_tokenizer'])
10661068
elif llm.lower() == 'prosst':
10671069
llm_dict = prosst_setup(
10681070
wt_seq, pdb_file, sequences=train_sequences)
1069-
x_llm_test = llm_tokenizer(llm_dict, test_sequences)
1071+
x_llm_test = tokenize_sequences(test_sequences, llm_dict['prosst']['llm_tokenizer'])
10701072
else:
10711073
llm_dict = None
10721074
x_llm_test = None
@@ -1111,8 +1113,10 @@ def performance_ls_ts(
11111113
substitution_sep, threads, False
11121114
)
11131115
if model.llm_model_input is not None:
1114-
logger.info(f"Found hybrid model with LLM {list(model.llm_model_input.keys())[0]}...")
1115-
x_llm_test = llm_tokenizer(model.llm_model_input, test_sequences)
1116+
llm_ = list(model.llm_model_input.keys())[0]
1117+
tokenizer = model.llm_model_input[llm_]['llm_tokenizer']
1118+
logger.info(f"Found hybrid model with LLM {llm_}...")
1119+
x_llm_test = tokenize_sequences(test_sequences, tokenizer)
11161120
y_test_pred = model.hybrid_prediction(x_test, x_llm_test)
11171121
else:
11181122
y_test_pred = model.hybrid_prediction(x_test)
@@ -1145,11 +1149,23 @@ def performance_ls_ts(
11451149
else:
11461150
model_type = 'LLM'
11471151
if llm == 'esm':
1152+
llm_dict = esm_setup(test_sequences[0], test_sequences) # TODO: Improve wt_seq input workaround
11481153
logger.info("Zero-shot LLM inference on test set using ESM1v...")
1149-
y_test_pred = inference(test_sequences, llm)
1154+
y_test_pred = plm_inference(
1155+
xs = llm_dict['esm1v']['x_llm'],
1156+
wt_input_ids=llm_dict['esm1v']['wt_input_ids'],
1157+
model=llm_dict['esm1v']['llm_base_model']
1158+
)
11501159
elif llm == 'prosst':
1160+
llm_dict = prosst_setup(test_sequences[0], test_sequences) # TODO: Improve wt_seq input workaround
11511161
logger.info("Zero-shot LLM inference on test set using ProSST...")
1152-
y_test_pred = inference(test_sequences, llm, pdb_file=pdb_file, wt_seq=wt_seq)
1162+
y_test_pred = plm_inference(
1163+
xs = llm_dict['prosst']['x_llm'],
1164+
wt_input_ids=llm_dict['prosst']['wt_input_ids'],
1165+
model=llm_dict['prosst']['llm_base_model'],
1166+
wt_structure_input_ids=llm_dict['prosst']['wt_structure_input_ids']
1167+
1168+
)
11531169
else:
11541170
raise RuntimeError("Unknown --llm flag option.")
11551171
else:
@@ -1264,11 +1280,13 @@ def predict_ps(
12641280
variants, sequences, None, params_file,
12651281
threads=threads, verbose=False, substitution_sep=separator
12661282
)
1267-
if model.llm_key is None:
1283+
if model.llm_key is None: # TODO: Check llm_key
12681284
ys_pred = model.hybrid_prediction(x_test)
12691285
else:
12701286
sequences = [str(seq) for seq in test_sequences]
1271-
x_llm_test = llm_tokenizer(model.llm_model_input, sequences)
1287+
llm_ = list(model.llm_model_input.keys())[0]
1288+
tokenizer = model.llm_model_input[llm_]['llm_tokenizer']
1289+
x_llm_test = tokenize_sequences(sequences, tokenizer)
12721290
ys_pred = model.hybrid_prediction(np.asarray(x_test), np.asarray(x_llm_test))
12731291
for k, y in enumerate(ys_pred):
12741292
all_y_v_pred.append((ys_pred[k], variants[k]))
@@ -1294,11 +1312,11 @@ def predict_ps(
12941312
if llm == 'esm':
12951313
model_type = 'LLM_ESM1v'
12961314
logger.info("Zero-shot LLM inference on test set using ESM1v...")
1297-
ys_pred = inference(sequences, llm)
1315+
ys_pred = plm_inference(sequences, llm) # TODO
12981316
elif llm == 'prosst':
12991317
model_type = 'LLM_ProSST'
13001318
logger.info("Zero-shot LLM inference on test set using ProSST...")
1301-
ys_pred = inference(sequences, llm, pdb_file=pdb_file, wt_seq=wt_seq)
1319+
ys_pred = plm_inference(sequences, llm, pdb_file=pdb_file, wt_seq=wt_seq) # TODO
13021320
else:
13031321
if not model_type.startswith('Hybrid'): # statistical DCA model
13041322
xs, variants, _, _, x_wt, *_ = plmc_or_gremlin_encoding(
@@ -1315,7 +1333,9 @@ def predict_ps(
13151333
ys_pred = model.hybrid_prediction(xs)
13161334
else:
13171335
sequences = [str(seq) for seq in sequences]
1318-
xs_llm = llm_tokenizer(model.llm_model_input, sequences)
1336+
llm_ = list(model.llm_model_input.keys())[0]
1337+
tokenizer = model.llm_model_input[llm_]['llm_tokenizer']
1338+
xs_llm = tokenize_sequences(sequences, tokenizer)
13191339
ys_pred = model.hybrid_prediction(np.asarray(xs), np.asarray(xs_llm))
13201340
assert len(xs) == len(variants) == len(ys_pred)
13211341
y_v_pred = zip(ys_pred, variants)
@@ -1375,7 +1395,7 @@ def predict_directed_evolution(
13751395
if model.llm_model_input is None:
13761396
y_pred = model.hybrid_prediction(xs)
13771397
else:
1378-
x_llm = llm_tokenizer(model.llm_model_input,
1398+
x_llm = tokenize_sequences(model.llm_model_input,
13791399
variant_sequence, verbose=False)
13801400

13811401
y_pred = model.hybrid_prediction(

0 commit comments

Comments
 (0)