1+ import torch
2+ import numpy as np
3+ from sklearn .gaussian_process import GaussianProcessRegressor
4+ from sklearn .gaussian_process .kernels import RBF , WhiteKernel
5+ from tqdm import tqdm
6+ import pandas as pd
7+ from sklearn .model_selection import train_test_split
8+ from scipy .stats import spearmanr
9+ import gpytorch
10+ import torch .nn .functional as F
11+ import os
12+ from proteinmpnn .protein_mpnn_utils import ProteinMPNN # pip install proteinmpnn
13+
14+ import torch
15+ import gpytorch
16+ import torch .nn .functional as F
17+
18+
19+ """
20+ pip install proteinmpnn
21+
22+ python -m proteinmpnn.protein_mpnn_run \
23+ --pdb-path "example_data/blat_ecolx/BLAT_ECOLX.pdb" \
24+ --save-score 1 \
25+ --conditional-probs-only 1 \
26+ --num-seq-per-target 10 \
27+ --batch-size 1 \
28+ --out-folder "pmpnn_out" \
29+ --seed 37
30+ """
31+
class HellingerRBFKernel(gpytorch.kernels.Kernel):
    """RBF-style kernel over probability vectors using the squared Hellinger distance.

    k(p, q) = variance * exp(-H2(p, q) / (2 * lengthscale**2)),
    where H2(p, q) = 0.5 * sum_i (sqrt(p_i) - sqrt(q_i))**2.

    Inputs are clamped to be non-negative and renormalized to sum to 1 along the
    last dimension, so arbitrary non-negative score vectors are accepted.
    """

    has_lengthscale = True  # GPyTorch handles the raw/constrained lengthscale pair

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Amplitude/variance (output-scale) parameter, stored unconstrained and
        # mapped through a Positive (softplus) constraint, per GPyTorch convention.
        self.register_parameter(
            name="raw_variance",
            parameter=torch.nn.Parameter(torch.tensor(0.0)),
        )
        self.register_constraint("raw_variance", gpytorch.constraints.Positive())

    @property
    def variance(self):
        # Constrained (positive) view of the raw parameter.
        return self.raw_variance_constraint.transform(self.raw_variance)

    @variance.setter
    def variance(self, value):
        self._set_variance(value)

    def _set_variance(self, value):
        # Route assignment through the inverse transform so the constrained
        # value equals `value` exactly. Accept plain floats for convenience.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value, dtype=self.raw_variance.dtype)
        self.raw_variance.data = self.raw_variance_constraint.inverse_transform(value)

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel.

        x1: [..., n1, d] non-negative score vectors (renormalized to probabilities)
        x2: [..., n2, d]
        diag: if True, return only k(x1[i], x2[i]) as a vector — GPyTorch
              requests this during prediction; requires n1 == n2.

        Returns: [..., n1, n2] covariance matrix, or [..., n1] when diag=True.
        """
        # Ensure valid probability vectors; guard the normalization against
        # rows that are all zero after clamping (avoids division by zero / NaNs).
        eps = 1e-12
        x1 = torch.clamp(x1, min=0)
        x2 = torch.clamp(x2, min=0)
        x1 = x1 / x1.sum(dim=-1, keepdim=True).clamp(min=eps)
        x2 = x2 / x2.sum(dim=-1, keepdim=True).clamp(min=eps)

        x1_sqrt = torch.sqrt(x1)
        x2_sqrt = torch.sqrt(x2)

        if diag:
            # Elementwise squared Hellinger distance between paired rows.
            h2 = 0.5 * ((x1_sqrt - x2_sqrt) ** 2).sum(dim=-1)
            # lengthscale carries trailing (1, 1) dims; drop them so the
            # result keeps the expected [..., n] diagonal shape.
            ls = self.lengthscale.squeeze(-1).squeeze(-1)
            return self.variance * torch.exp(-h2 / (2 * ls ** 2))

        # Pairwise squared Hellinger distance: [..., n1, n2].
        diff2 = (x1_sqrt.unsqueeze(-2) - x2_sqrt.unsqueeze(-3)) ** 2
        h2 = 0.5 * diff2.sum(dim=-1)

        # RBF-like kernel on the Hellinger distance.
        return self.variance * torch.exp(-h2 / (2 * self.lengthscale ** 2))
77+
78+
class GPModel(gpytorch.models.ExactGP):
    """Exact GP regression model with a zero mean and a user-supplied kernel.

    The kernel (e.g. a Kermut-style HellingerRBFKernel) provides all of the
    model's structure; the prior mean is fixed at zero.
    """

    def __init__(self, train_x, train_y, likelihood, kernel):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # Covariance comes entirely from the injected kernel.
        self.covar_module = kernel

    def forward(self, x):
        # Prior distribution at the inputs: N(0, K(x, x)).
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x, x)
        )
89+
90+
def get_probs_from_pmlnn_npz(npz_file="pmpnn_out/conditional_probs_only/BLAT_ECOLX.npz"):
    """Load ProteinMPNN conditional log-probabilities and return per-position
    amino-acid distributions.

    Reads the ``log_p`` array from the NPZ archive, keeps the first 20 columns
    (the canonical amino acids), averages the log-probabilities over the first
    axis, and renormalizes each position with a softmax.

    Returns a torch tensor of shape [L, 20].
    """
    print(f"Getting PMPNN amino acid probs from NPZ file: {os.path.abspath(npz_file)} ...")
    with np.load(npz_file) as archive:
        # Keep only the 20 canonical amino-acid columns.
        log_ps = archive["log_p"][..., :20]
    # Average log-probs across samples, then renormalize per position.
    mean_log_ps = torch.tensor(np.mean(log_ps, axis=0))  # [L, 20]
    return F.softmax(mean_log_ps, dim=-1)  # [L, 20]
102+
103+
def get_probs_from_mutations(mutations, probs=None):
    """Collect the per-position probability rows for a list of mutations.

    mutations: iterable of single-mutation strings such as "A42G"
               (wild-type AA, 1-based position, mutant AA).
    probs: optional [L, 20] tensor of per-position distributions; when None,
           they are loaded via get_probs_from_pmlnn_npz().

    Returns a tensor of shape [N, 20], one row per mutation.
    """
    if probs is None:
        probs = get_probs_from_pmlnn_npz()
    # m[1:-1] is the 1-based position; only single mutants are supported for now.
    rows = [probs[int(mut[1:-1]) - 1] for mut in mutations]
    return torch.stack(rows)
112+
113+ # For now, running with CLI...
114+ # https://github.com/petergroth/kermut/blob/main/example_scripts/conditional_probabilities_single.sh
115+ """
116+ python -m proteinmpnn.protein_mpnn_run \
117+ --pdb-path "example_data/blat_ecolx/BLAT_ECOLX.pdb" \
118+ --save-score 1 \
119+ --conditional-probs-only 1 \
120+ --num-seq-per-target 10 \
121+ --batch-size 1 \
122+ --out-folder "pmpnn_out" \
123+ --seed 37
124+ """
125+
126+
if __name__ == '__main__':
    # --- Inspect the raw ProteinMPNN output archive -------------------------
    # Load the conditional-probability NPZ produced by the CLI invocation shown
    # in the module docstring, copying every array out so the file can be closed.
    data = np.load("pmpnn_out/conditional_probs_only/BLAT_ECOLX.npz")
    data_dict = {k: data[k] for k in data.files}
    data.close()

    # Quick sanity check: dump every stored array and its shape.
    for k, v in data_dict.items():
        print(f"K:{ k } \n v:{ v } \n { np .shape (v )} \n \n ")

    # --- Plotting setup (plots themselves are currently commented out) ------
    import matplotlib.pyplot as plt
    from matplotlib.cm import get_cmap
    from cycler import cycler

    # One distinct rainbow color per amino acid so per-AA plot lines are
    # distinguishable.
    cmap = get_cmap('rainbow')
    amino_acids = list('ACDEFGHIKLMNPQRSTVWY')  # Excluded: X
    colors = [cmap(i / len(amino_acids)) for i in range(len(amino_acids))]
    plt.rcParams['axes.prop_cycle'] = cycler(color=colors)

    # Keep only the 20 canonical amino-acid columns of the log-probabilities.
    log_ps = data_dict["log_p"][..., :20]

    # Average log-probabilities over the first axis (samples).
    mean_probs = np.mean(log_ps, axis=0)

    print(np.shape(mean_probs))
    #plt.figure(figsize=(20,5))
    #plt.plot(mean_probs, label=amino_acids, linewidth=0.5)
    # Annotate above peaks
    #for pos, (aa, val) in enumerate(zip(top_aa, top_val)):
    #    if pos % 1 == 0:
    #        plt.text(pos, val + 0.02, aa, ha='center', va='bottom', fontsize=4, color='black')

    #plt.legend(ncol=7)
    #plt.show()

    # Softmax over the averaged log-probs yields a per-position amino-acid
    # distribution. NOTE(review): this duplicates get_probs_from_pmlnn_npz().
    mean_probs = torch.tensor(mean_probs)  # [L, 20]
    probs = F.softmax(mean_probs, dim=-1)  # [L, 20]

    # --- Load the DMS assay data and split train/test -----------------------
    df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')
    print(df.columns)
    mutants = df['mutant'].to_list()
    sequences = df['mutated_sequence'].to_list()
    y = df['DMS_score'].to_list()

    m_train, m_test, s_train, s_test, y_train, y_test = train_test_split(
        mutants, sequences, y, test_size=0.33, random_state=42)  # train_size=100, test_size=200,

    # Features: the ProteinMPNN probability row at each mutation's position.
    X_train = get_probs_from_mutations(m_train)  # shape [N, 20]

    y_train = torch.tensor(y_train)  # shape [N]

    # Initialize kernel
    kernel = HellingerRBFKernel()
    kernel.lengthscale = torch.tensor(0.5)  # optional manual override
    kernel.variance = torch.tensor(1.0)

    # Compute covariance matrix (over all positions, for visual inspection only)
    K = kernel(probs, probs)
    print(K.shape)  # [286, 286]

    # Visualize
    #plt.figure(figsize=(8, 6))
    #plt.imshow(K.detach().numpy(), cmap='viridis')
    #plt.colorbar(label='Kernel value')
    #plt.title('Hellinger RBF Kernel Matrix')
    #plt.xlabel('Sequence position')
    #plt.ylabel('Sequence position')
    #plt.show()

    # --- Training: exact GP, maximizing the marginal log-likelihood ---------
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GPModel(X_train, y_train, likelihood, kernel)

    model.train()
    likelihood.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # Standard GPyTorch training loop: minimize the negative MLL for 100 steps.
    for i in tqdm(range(100)):
        optimizer.zero_grad()
        output = model(X_train)
        loss = -mll(output, y_train)
        loss.backward()
        optimizer.step()

    # --- Evaluation ----------------------------------------------------------
    X_test = get_probs_from_mutations(m_test)
    model.eval()
    likelihood.eval()

    # Posterior predictive mean on the training set (sanity check for fit).
    with torch.no_grad():
        pred_train = likelihood(model(X_train))
        y_pred_train = pred_train.mean.cpu().numpy()

    # Posterior predictive mean on the held-out test set.
    with torch.no_grad():
        pred_test = likelihood(model(X_test))
        y_pred_test = pred_test.mean.cpu().numpy()

    from scipy.stats import spearmanr
    # Rank correlation between predicted and measured DMS scores.
    rho, p = spearmanr(y_train.numpy(), y_pred_train)
    print("Spearman rho TRAIN:", rho)
    print("p-value TRAIN:", p)
    rho, p = spearmanr(y_test, y_pred_test)
    print("Spearman rho:", rho)
    print("p-value:", p)