Skip to content

Commit 0e927dd

Browse files
committed
dev: add metric.py (different Spearman torch implementations)
1 parent 88919f6 commit 0e927dd

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
import torch
2+
import numpy as np
3+
from sklearn.gaussian_process import GaussianProcessRegressor
4+
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
5+
from tqdm import tqdm
6+
import pandas as pd
7+
from sklearn.model_selection import train_test_split
8+
from scipy.stats import spearmanr
9+
import gpytorch
10+
import torch.nn.functional as F
11+
import os
12+
from proteinmpnn.protein_mpnn_utils import ProteinMPNN # pip install proteinmpnn
13+
14+
import torch
15+
import gpytorch
16+
import torch.nn.functional as F
17+
18+
19+
"""
20+
pip install proteinmpnn
21+
22+
python -m proteinmpnn.protein_mpnn_run \
23+
--pdb-path "example_data/blat_ecolx/BLAT_ECOLX.pdb" \
24+
--save-score 1 \
25+
--conditional-probs-only 1 \
26+
--num-seq-per-target 10 \
27+
--batch-size 1 \
28+
--out-folder "pmpnn_out" \
29+
--seed 37
30+
"""
31+
32+
class HellingerRBFKernel(gpytorch.kernels.Kernel):
    """RBF-style kernel over probability vectors using the squared Hellinger distance.

    k(p, q) = variance * exp(-H^2(p, q) / (2 * lengthscale^2)),
    where H^2(p, q) = 0.5 * sum_i (sqrt(p_i) - sqrt(q_i))^2.
    """

    has_lengthscale = True  # GPyTorch registers/constrains the log-lengthscale automatically

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Amplitude/variance (output scale), stored unconstrained as raw_variance.
        self.register_parameter(
            name="raw_variance",
            parameter=torch.nn.Parameter(torch.tensor(0.0))
        )
        self.register_constraint("raw_variance", gpytorch.constraints.Positive())

    @property
    def variance(self):
        # Map the raw (unconstrained) parameter into the positive domain.
        return self.raw_variance_constraint.transform(self.raw_variance)

    @variance.setter
    def variance(self, value):
        self._set_variance(value)

    def _set_variance(self, value):
        # Accept plain floats as well as tensors; previously a float here would
        # break inside inverse_transform.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_variance)
        # Set raw_variance via the inverse transform so `variance` round-trips exactly.
        self.raw_variance.data = self.raw_variance_constraint.inverse_transform(value)

    def forward(self, x1, x2, **params):
        """
        x1: [n1, d] rows interpreted as (possibly unnormalized) probabilities
        x2: [n2, d]
        Returns: covariance matrix [n1, n2]
        """
        # Project onto the simplex: clamp negatives, renormalize rows.
        x1 = torch.clamp(x1, min=0)
        x2 = torch.clamp(x2, min=0)
        # Bug fix: guard against all-zero rows, which previously produced
        # NaNs via 0/0 during normalization.
        x1 = x1 / x1.sum(dim=1, keepdim=True).clamp_min(1e-12)
        x2 = x2 / x2.sum(dim=1, keepdim=True).clamp_min(1e-12)

        # Squared Hellinger distance between every pair of rows.
        x1_sqrt = torch.sqrt(x1)
        x2_sqrt = torch.sqrt(x2)
        diff2 = (x1_sqrt.unsqueeze(1) - x2_sqrt.unsqueeze(0))**2
        H2 = 0.5 * diff2.sum(dim=2)  # [n1, n2]

        # RBF-like kernel on the Hellinger distance.
        K = self.variance * torch.exp(-H2 / (2 * self.lengthscale ** 2))
        return K
77+
78+
79+
class GPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a zero mean and a caller-supplied kernel."""

    def __init__(self, train_x, train_y, likelihood, kernel):
        super().__init__(train_x, train_y, likelihood)
        # All structure comes from the supplied kernel (e.g. the Kermut /
        # Hellinger-RBF kernel); the prior mean is fixed at zero.
        self.covar_module = kernel
        self.mean_module = gpytorch.means.ZeroMean()

    def forward(self, x):
        # Prior over function values at x.
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x, x)
        )
89+
90+
91+
def get_probs_from_pmlnn_npz(npz_file="pmpnn_out/conditional_probs_only/BLAT_ECOLX.npz"):
    """Load ProteinMPNN conditional log-probabilities and return per-position AA probabilities.

    Args:
        npz_file: Path to the .npz produced by protein_mpnn_run with
            --conditional-probs-only; must contain a "log_p" array of shape
            [n_samples, L, 21] (20 standard amino acids + "X").

    Returns:
        torch.Tensor of shape [L, 20]: softmax over the sample-averaged
        log-probabilities, restricted to the 20 standard amino acids.
    """
    print(f"Getting PMPNN amino acid probs from NPZ file: {os.path.abspath(npz_file)}...")
    # Bug fix: use a context manager so the npz file handle is closed even if
    # reading raises; also avoids copying every array into an unused dict.
    with np.load(npz_file) as data:
        log_ps = data["log_p"][..., :20]  # drop the 21st class ("X")

    # Average log-probs over the sampled decoding orders, then renormalize
    # into proper per-position probabilities via softmax.
    mean_log_ps = torch.tensor(np.mean(log_ps, axis=0))  # [L, 20]
    probs = F.softmax(mean_log_ps, dim=-1)  # [L, 20]
    return probs
102+
103+
104+
def get_probs_from_mutations(mutations, probs=None):
    """Look up the per-position amino-acid probability vector for each mutation.

    Args:
        mutations: Sequence of single-mutation strings like "A123G"
            (wild-type AA, 1-based position, substituted AA).
        probs: Optional [L, 20] tensor of per-position probabilities; loaded
            from the default ProteinMPNN npz file when None.

    Returns:
        torch.Tensor of shape [N, 20], one row per mutation. An empty
        mutation list yields an empty [0, 20] tensor.
    """
    if probs is None:
        probs = get_probs_from_pmlnn_npz()
    if len(mutations) == 0:
        # Bug fix: torch.stack([]) raises on an empty list; return a
        # correctly-shaped empty tensor instead.
        return probs.new_empty((0, probs.shape[-1]))
    # m[1:-1] is the 1-based residue position (single mutations only for now).
    return torch.stack([probs[int(m[1:-1]) - 1] for m in mutations])
112+
113+
# For now, running with CLI...
114+
# https://github.com/petergroth/kermut/blob/main/example_scripts/conditional_probabilities_single.sh
115+
"""
116+
python -m proteinmpnn.protein_mpnn_run \
117+
--pdb-path "example_data/blat_ecolx/BLAT_ECOLX.pdb" \
118+
--save-score 1 \
119+
--conditional-probs-only 1 \
120+
--num-seq-per-target 10 \
121+
--batch-size 1 \
122+
--out-folder "pmpnn_out" \
123+
--seed 37
124+
"""
125+
126+
127+
if __name__ == '__main__':
    # Inspect the raw ProteinMPNN output: load every array stored in the npz archive.
    data = np.load("pmpnn_out/conditional_probs_only/BLAT_ECOLX.npz")
    data_dict = {k: data[k] for k in data.files}
    data.close()

    # Dump each array's name, contents, and shape for a quick sanity check.
    for k, v in data_dict.items():
        print(f"K:{k}\nv:{v}\n{np.shape(v)}\n\n")

    import matplotlib.pyplot as plt
    from matplotlib.cm import get_cmap
    from cycler import cycler
    # One distinct rainbow color per amino acid (used by the currently
    # commented-out per-position probability plots below).
    cmap = get_cmap('rainbow')
    amino_acids = list('ACDEFGHIKLMNPQRSTVWY')  # Excluded: X
    colors = [cmap(i / len(amino_acids)) for i in range(len(amino_acids))]
    plt.rcParams['axes.prop_cycle'] = cycler(color=colors)
    # Keep only the 20 standard amino acids (drop the 21st "X" class).
    log_ps = data_dict["log_p"][..., :20]

    # Average the log-probabilities over the sampled decoding orders.
    # NOTE(review): name says "probs" but these are mean *log*-probs until the
    # softmax below — confirm intended.
    mean_probs = np.mean(log_ps, axis=0)

    print(np.shape(mean_probs))
    #plt.figure(figsize=(20,5))
    #plt.plot(mean_probs, label=amino_acids, linewidth=0.5)
    # Annotate above peaks
    #for pos, (aa, val) in enumerate(zip(top_aa, top_val)):
    #    if pos % 1 == 0:
    #        plt.text(pos, val + 0.02, aa, ha='center', va='bottom', fontsize=4, color='black')

    #plt.legend(ncol=7)
    #plt.show()

    # Renormalize the averaged log-probs into per-position probabilities.
    mean_probs = torch.tensor(mean_probs)  # [L, 20]
    probs = F.softmax(mean_probs, dim=-1)  # [L, 20]

    # Load the DMS assay data: mutant codes, full sequences, and fitness scores.
    df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')
    print(df.columns)
    mutants = df['mutant'].to_list()
    sequences = df['mutated_sequence'].to_list()
    y = df['DMS_score'].to_list()

    m_train, m_test, s_train, s_test, y_train, y_test = train_test_split(
        mutants, sequences, y, test_size=0.33, random_state=42)  # train_size=100, test_size=200,

    # Featurize each training mutant by its position's PMPNN probability vector.
    X_train = get_probs_from_mutations(m_train)  # shape [N, 20]

    y_train = torch.tensor(y_train)  # shape [N]

    # Initialize kernel
    kernel = HellingerRBFKernel()
    kernel.lengthscale = torch.tensor(0.5)  # optional manual override
    kernel.variance = torch.tensor(1.0)

    # Compute covariance matrix (all positions vs. all positions, as a smoke test)
    K = kernel(probs, probs)
    print(K.shape)  # [286, 286]

    # Visualize
    #plt.figure(figsize=(8, 6))
    #plt.imshow(K.detach().numpy(), cmap='viridis')
    #plt.colorbar(label='Kernel value')
    #plt.title('Hellinger RBF Kernel Matrix')
    #plt.xlabel('Sequence position')
    #plt.ylabel('Sequence position')
    #plt.show()

    # Training: exact GP with Gaussian noise, maximizing the marginal log-likelihood.
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GPModel(X_train, y_train, likelihood, kernel)

    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # 100 Adam steps on the negative marginal log-likelihood.
    for i in tqdm(range(100)):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    X_test = get_probs_from_mutations(m_test)
    model.eval()
    likelihood.eval()

    # Posterior predictive means on the training set (fit quality check).
    with torch.no_grad():
        pred_train = likelihood(model(X_train))
        y_pred_train = pred_train.mean.cpu().numpy()

    # Posterior predictive means on the held-out test set.
    with torch.no_grad():
        pred_test = likelihood(model(X_test))
        y_pred_test = pred_test.mean.cpu().numpy()

    # Evaluate with rank correlation (standard metric for DMS fitness prediction).
    from scipy.stats import spearmanr
    rho, p = spearmanr(y_train.numpy(), y_pred_train)
    print("Spearman rho TRAIN:", rho)
    print("p-value TRAIN:", p)
    rho, p = spearmanr(y_test, y_pred_test)
    print("Spearman rho:", rho)
    print("p-value:", p)

pypef/gaussian_process/metrics.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
import torchsort
3+
4+
5+
6+
def spearmanr2(pred, target, **kw):
    """Differentiable Spearman correlation via torchsort soft ranks.

    Adapted from https://github.com/teddykoker/torchsort/blob/main/README.md.
    Extra keyword arguments are forwarded to torchsort.soft_rank.
    """
    rank_pred = torchsort.soft_rank(pred, **kw)
    rank_target = torchsort.soft_rank(target, **kw)
    # Pearson correlation of the soft ranks: center, scale to unit norm, dot.
    centered_pred = rank_pred - rank_pred.mean()
    centered_target = rank_target - rank_target.mean()
    unit_pred = centered_pred / centered_pred.norm()
    unit_target = centered_target / centered_target.norm()
    return (unit_pred * unit_target).sum()
15+
16+
17+
18+
def soft_rank_approx(x, tau=1.0):
    """
    A simple soft rank approximation using pairwise sigmoid comparisons.

    Args:
        x: tensor of shape (..., n)
        tau: temperature (larger = softer, smaller = closer to true ranks)
    Returns:
        approximate (1-based) ranks, same shape as x
    """
    # Soft pairwise "x_i > x_j" scores; the self-comparison contributes
    # sigmoid(0) = 0.5, and the extra +0.5 shifts values toward 1-based ranks.
    pairwise = torch.sigmoid((x.unsqueeze(-1) - x.unsqueeze(-2)) / tau)
    return pairwise.sum(dim=-1) + 0.5


def spearman_soft(x, y, tau=1.0):
    """Differentiable Spearman correlation built on soft_rank_approx ranks.

    Computes the Pearson correlation of the soft ranks along the last
    dimension; returns a tensor with the trailing dimension reduced.
    """
    rank_x = soft_rank_approx(x, tau)
    rank_y = soft_rank_approx(y, tau)

    # Center each rank vector along the last axis.
    cx = rank_x - rank_x.mean(-1, keepdim=True)
    cy = rank_y - rank_y.mean(-1, keepdim=True)

    # Scale to unit norm; epsilon avoids 0/0 for constant inputs.
    nx = cx / (cx.norm(dim=-1, keepdim=True) + 1e-8)
    ny = cy / (cy.norm(dim=-1, keepdim=True) + 1e-8)

    return (nx * ny).sum(dim=-1)
47+
48+
49+
50+
51+
def spearman_corr_differentiable(pred: torch.Tensor, target: torch.Tensor,
                                 regularization_strength: float = 1.0,
                                 regularization: str = "l2"):
    """
    REQUIRES TORCHSORT
    Compute a differentiable Spearman correlation coefficient between pred and target.
    Works on [batch_size, n] tensors; preserves gradients for backprop.

    Args:
        pred: predictions, shape [batch_size, n].
        target: targets, shape [batch_size, n].
        regularization_strength: soft-rank regularization strength
            (smaller = closer to hard ranks).
        regularization: soft-rank regularization type ("l2" or "kl").

    Returns:
        Tensor of shape [batch_size] with the soft Spearman correlation per row.
    """
    # Soft ranks. Bug fix: forward the caller's `regularization` argument
    # instead of hard-coding "l2", so the parameter actually has an effect.
    pred_rank = torchsort.soft_rank(pred, regularization=regularization,
                                    regularization_strength=regularization_strength)
    target_rank = torchsort.soft_rank(target, regularization=regularization,
                                      regularization_strength=regularization_strength)

    # Center and normalize; epsilon guards against zero norm for constant rows.
    pred_rank = pred_rank - pred_rank.mean(dim=-1, keepdim=True)
    pred_rank = pred_rank / (pred_rank.norm(dim=-1, keepdim=True) + 1e-8)
    target_rank = target_rank - target_rank.mean(dim=-1, keepdim=True)
    target_rank = target_rank / (target_rank.norm(dim=-1, keepdim=True) + 1e-8)

    # Spearman = dot product of normalized rank vectors.
    return (pred_rank * target_rank).sum(dim=-1)

0 commit comments

Comments
 (0)