
Commit 88919f6

dev: add first GP scripts
1 parent 663a163 commit 88919f6

File tree: 8 files changed (+648, -30 lines)

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
from sklearn.model_selection import train_test_split
import torch

import gpytorch
import pandas as pd
from scipy.stats import spearmanr
from tqdm import tqdm

# note: imported for the commented-out ESM variant below
from gp_esm2_test import extract_esm_embeddings
from gp_pmpnn_test import HellingerRBFKernel, get_probs_from_mutations
from gp_prosst_test import (extract_prosst_embeddings, get_prosst_models,
                            get_structure_quantizied, read_fasta_biopython)
from metrics import spearman_soft, spearman_corr_differentiable, spearmanr2

class CombinedKernel(gpytorch.kernels.Kernel):
    """
    Combine two kernels: K_seq + K_struct.
    Input X is a single concatenated tensor: [seq | struct].
    """

    def __init__(self, kernel_seq, kernel_struct, d_seq):
        super().__init__()
        self.kernel_seq = kernel_seq
        self.kernel_struct = kernel_struct
        self.d_seq = d_seq  # number of sequence dimensions

    def forward(self, X1, X2, **params):
        X1_seq, X1_struct = X1[:, :self.d_seq], X1[:, self.d_seq:]
        X2_seq, X2_struct = X2[:, :self.d_seq], X2[:, self.d_seq:]

        K_seq = self.kernel_seq(X1_seq, X2_seq)
        K_struct = self.kernel_struct(X1_struct, X2_struct)

        return K_seq + K_struct  # could also use a product or weighted sum

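# Sketch (an assumption, not part of this commit): a learnably weighted sum
# can be had by wrapping each sub-kernel in a ScaleKernel, whose outputscale
# then acts as the mixing weight, mirroring the construction further below:
#
#   combined_kernel = CombinedKernel(
#       gpytorch.kernels.ScaleKernel(seq_kernel),
#       gpytorch.kernels.ScaleKernel(struct_kernel),
#       d_seq=d_seq,
#   )
#
# A multiplicative combination would instead return K_seq * K_struct in
# forward(); gpytorch's lazy kernel tensors support both + and *.
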
class MultiInputGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, kernel):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = kernel

    def forward(self, X):
        mean_x = self.mean_module(X)
        covar_x = self.covar_module(X, X)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# -----------------------------
# Load and preprocess data
# -----------------------------
df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')

print(df.columns)
mutants = df['mutant'].to_list()
sequences = df['mutated_sequence'].to_list()
y = df['DMS_score'].to_list()

m_train, m_test, s_train, s_test, y_train, y_test = train_test_split(
    mutants, sequences, y, test_size=0.33, random_state=42
)

X_struct = get_probs_from_mutations(m_train)  # [N, 20]

print("Getting ProSST models")
pdb = 'example_data/blat_ecolx/BLAT_ECOLX.pdb'
wt_seq = list(read_fasta_biopython('example_data/blat_ecolx/blat_ecolx_wt_seq.fa').values())[0]
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
prosst_vocab = prosst_tokenizer.get_vocab()
prosst_base_model = prosst_base_model.to("cuda")

input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
    pdb, prosst_tokenizer, wt_seq, verbose=True
)
wt_structure_input_ids = structure_input_ids[0, 1:-1].tolist()  # Remove CLS/EOS
#X_seq = torch.tensor(extract_esm_embeddings(s_train)).float()  # [N, d_seq]
X_seq = torch.tensor(extract_prosst_embeddings(
    prosst_base_model, prosst_tokenizer, s_train, wt_structure_input_ids
)).float()  # cast to float32 so dtypes match X_struct and y_train
y_train = torch.tensor(y_train).float()
y_test = torch.tensor(y_test).float()

# Concatenate features: the GP kernel takes a single tensor, not a tuple,
# so sequence and structure features are stacked along the feature dim
X_combined = torch.cat([X_seq, X_struct], dim=-1)
d_seq = X_seq.shape[1]

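# Shape sketch (illustrative numbers only, assuming a 768-dim ProSST
# embedding): X_combined would be [N, 788]; CombinedKernel.forward later
# slices it back into X[:, :d_seq] and X[:, d_seq:] per kernel.
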
# -----------------------------
# Define kernels and model
# -----------------------------
seq_kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
struct_kernel = HellingerRBFKernel()
combined_kernel = CombinedKernel(seq_kernel, struct_kernel, d_seq=d_seq)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = MultiInputGP(X_combined, y_train, likelihood, combined_kernel)

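# Note (assumption): everything below runs on CPU; moving X_combined,
# y_train, model, and likelihood to "cuda" (as the ESM script does) should
# behave identically and train considerably faster.
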
# -----------------------------
# Train
# -----------------------------
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

pbar = tqdm(range(100), desc='Training')
for i in pbar:
    optimizer.zero_grad()
    output = model(X_combined)
    loss = -mll(output, y_train)
    loss.backward()
    optimizer.step()
    pbar.set_description(f"Training (loss: {loss:.4f})")

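# Optional extension (a sketch, untested here): persist the learned kernel
# and likelihood hyperparameters with
#   torch.save(model.state_dict(), "gp_combined.pt")
# and restore them later via
#   model.load_state_dict(torch.load("gp_combined.pt"))
# (exact-GP models also need the original training data when rebuilt).
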
# -----------------------------
# Test
# -----------------------------
X_struct_test = get_probs_from_mutations(m_test)
#X_seq_test = torch.tensor(extract_esm_embeddings(s_test)).float()
X_seq_test = torch.tensor(extract_prosst_embeddings(
    prosst_base_model, prosst_tokenizer, s_test, wt_structure_input_ids
)).float()  # same float32 cast as the training features
X_test_combined = torch.cat([X_seq_test, X_struct_test], dim=-1)

model.eval()
likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    pred_train = likelihood(model(X_combined))
    y_pred_train = pred_train.mean.cpu().numpy()

    pred = likelihood(model(X_test_combined))
    y_pred = pred.mean.cpu().numpy()
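    # Optional (sketch): the GP posterior also carries per-point uncertainty,
    # e.g. y_std = pred.stddev.cpu().numpy() or
    # lower, upper = pred.confidence_region()  # roughly a 95% interval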

rho, p = spearmanr(y_train, y_pred_train)
print("Spearman rho SciPy TRAIN:", rho)
print("Spearman soft TRAIN:", spearman_soft(y_train, torch.from_numpy(y_pred_train)).item())
y_train_t = y_train.float().unsqueeze(0)  # shape (1, n)
y_pred_train_t = torch.from_numpy(y_pred_train).float().unsqueeze(0)  # shape (1, n)
print("Spearman corr diff (ChatGPT) TRAIN:", spearman_corr_differentiable(y_train_t, y_pred_train_t).item())
print("Spearman2 torchsort TRAIN:", spearmanr2(y_train_t, y_pred_train_t).item())

rho, p = spearmanr(y_test, y_pred)
print("Spearman rho SciPy TEST:", rho)
print("Spearman soft TEST:", spearman_soft(y_test, torch.from_numpy(y_pred)).item())
y_test_t = y_test.float().unsqueeze(0)  # shape (1, n)
y_pred_t = torch.from_numpy(y_pred).float().unsqueeze(0)  # shape (1, n)
print("Spearman corr diff (ChatGPT) TEST:", spearman_corr_differentiable(y_test_t, y_pred_t).item())
print("Spearman2 torchsort TEST:", spearmanr2(y_test_t, y_pred_t).item())
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
import torch
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from scipy.stats import spearmanr
import gpytorch

# --- Step 1: Load a pretrained ESM plm_model ---
from esm import pretrained  # pip install fair-esm

"""
Alternatively, install from source:
git clone https://github.com/facebookresearch/esm.git
cd esm
pip install .
"""

device = "cuda" if torch.cuda.is_available() else "cpu"

USE_SCIKIT_LEARN = False


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


def extract_esm_embeddings(sequences):
    embeddings = []

    for seq in tqdm(sequences, 'Embedding (ESM)'):
        data = [("protein", seq)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = plm_model(batch_tokens, repr_layers=[33], return_contacts=False)
            token_representations = results["representations"][33]
            # Mean-pool per-residue representations (excluding special tokens)
            seq_embedding = token_representations[0, 1:len(seq)+1].mean(0)
        embeddings.append(seq_embedding.cpu().numpy())

    X = np.vstack(embeddings)

    return X

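# Usage sketch (assumes the module-level plm_model / batch_converter below):
#   X = extract_esm_embeddings(["MKT...", "MKV..."])
# returns an np.ndarray of shape [N, 1280] for the 650M (t33) ESM-2 model.
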
# Load ESM-2 (other checkpoint sizes exist: 35M, 150M, 650M, 3B)
plm_model, alphabet = pretrained.esm2_t33_650M_UR50D()
plm_model = plm_model.to(device)
batch_converter = alphabet.get_batch_converter()
plm_model.eval()  # disable dropout

if __name__ == '__main__':
    # --- Example dataset ---
    # sequences: list of amino acid strings
    # y: list/array of experimental fitness values

    df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')
    sequences = df['mutated_sequence'].to_list()
    y = df['DMS_score'].to_list()

    s_train, s_test, y_train, y_test = train_test_split(
        sequences, y, test_size=0.33, random_state=42)  # train_size=100, test_size=200,

    # --- Step 2: Extract ESM embeddings ---
    X = extract_esm_embeddings(s_train)
    print("Embedding extraction done")
    print(np.shape(X))

    # --- Step 3: Build and fit a Gaussian Process ---

    if USE_SCIKIT_LEARN:
        # RBF kernel + WhiteKernel (noise term)
        kernel = 1.0 * RBF(length_scale=1.0) + WhiteKernel(noise_level=0.1)
        gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, normalize_y=True)
        gpr.fit(X, y_train)
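        # Note: WhiteKernel learns an additive noise level here, playing the
        # same role as GaussianLikelihood's learned noise term in the
        # GPyTorch branch below.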

    else:  # GPYTORCH
        # Suppose X: [num_sequences, embedding_dim], y: [num_sequences]
        X = torch.tensor(X, dtype=torch.float32).to(device)
        y_train = torch.tensor(y_train, dtype=torch.float32).to(device)

        # Likelihood and model
        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
        gp_model = ExactGPModel(X, y_train, likelihood).to(device)

        gp_model.train()
        likelihood.train()

        optimizer = torch.optim.Adam(gp_model.parameters(), lr=0.1)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)

        training_iter = 100
        for i in range(training_iter):
            optimizer.zero_grad()
            output = gp_model(X)
            loss = -mll(output, y_train)
            loss.backward()
            print(f"Iter {i+1}/{training_iter} - Loss: {loss.item():.3f}")
            optimizer.step()

    # --- Step 4: Predict on new sequences ---
    # Reuse the embedding helper instead of duplicating the extraction loop
    X_test = extract_esm_embeddings(s_test)
    print("Test embeddings shape:", X_test.shape)

    if USE_SCIKIT_LEARN:
        y_mean, y_std = gpr.predict(X_test, return_std=True)
    else:  # GPYTORCH
        # Suppose X_test: [num_test, embedding_dim]
        X_test = torch.tensor(X_test, dtype=torch.float32).to(device)
        gp_model.eval()
        likelihood.eval()

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            pred = likelihood(gp_model(X_test))
            y_mean = pred.mean  # predicted mean
            lower, upper = pred.confidence_region()  # ~95% confidence interval
            y_std = pred.stddev.cpu().numpy()  # so both branches report std
            y_mean = y_mean.cpu().numpy()

    print("Predicted fitness:", y_mean)
    print("Uncertainty (std):", y_std)

    print(spearmanr(y_test, y_mean))

pypef/gaussian_process/gp_opt.py

Whitespace-only changes.
