Skip to content

Commit 9d07349

Browse files
committed
Add combined corr. loss function
1 parent 4c5986a commit 9d07349

File tree

7 files changed

+136
-105
lines changed

7 files changed

+136
-105
lines changed

pypef/gaussian_process/composite.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@
1111
from pypef.gaussian_process.gp_pmpnn_test import HellingerRBFKernel, get_probs_from_mutations
1212
from pypef.gaussian_process.gp_prosst_test import (extract_prosst_embeddings, get_prosst_models,
1313
get_structure_quantizied, read_fasta_biopython)
14-
from pypef.gaussian_process.metrics import spearman_soft, spearman_corr_differentiable, spearmanr2
14+
from pypef.plm.utils import spearman_soft, correlation_loss, hybrid_corr_mse_loss, pearson_loss
15+
1516

1617
class CombinedKernel(gpytorch.kernels.Kernel):
1718
"""
@@ -47,30 +48,26 @@ def forward(self, X):
4748
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
4849

4950

50-
51-
52-
53-
5451
# -----------------------------
5552
# Load and preprocess data
5653
# -----------------------------
57-
df = pd.read_csv('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv')
54+
df = pd.read_csv('datasets/BLAT_ECOLX/BLAT_ECOLX_Stiffler_2015.csv')
5855

5956
print(df.columns)
6057
mutants = df['mutant'].to_list()
6158
sequences = df['mutated_sequence'].to_list()
6259
y = df['DMS_score'].to_list()
6360

6461
m_train, m_test, s_train, s_test, y_train, y_test = train_test_split(
65-
mutants, sequences, y, test_size=0.33, random_state=42
62+
mutants, sequences, y, train_size=100, test_size=100, random_state=42
6663
)
6764

68-
X_struct = get_probs_from_mutations(m_train) # [N, 20]
65+
#X_struct = get_probs_from_mutations(m_train) # [N, 20]
6966

7067

7168
print("Getting ProSST models")
72-
pdb = 'example_data/blat_ecolx/BLAT_ECOLX.pdb'
73-
wt_seq = list(read_fasta_biopython('example_data/blat_ecolx/blat_ecolx_wt_seq.fa').values())[0]
69+
pdb = 'datasets/BLAT_ECOLX/BLAT_ECOLX.pdb'
70+
wt_seq = list(read_fasta_biopython('datasets/BLAT_ECOLX/blat_ecolx_wt.fasta').values())[0]
7471
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
7572
prosst_vocab = prosst_tokenizer.get_vocab()
7673
prosst_base_model = prosst_base_model.to("cuda")
@@ -87,7 +84,7 @@ def forward(self, X):
8784
y_test = torch.tensor(y_test).float()
8885

8986
# Concatenate features
90-
X_combined = torch.cat([X_seq, X_struct], dim=-1) # Concenation is necessary as GPkernel does not accept a tuple as input
87+
X_combined = torch.cat([X_seq, X_seq], dim=-1) # Concenation is necessary as GPkernel does not accept a tuple as input
9188
d_seq = X_seq.shape[1]
9289

9390
# -----------------------------
@@ -121,10 +118,10 @@ def forward(self, X):
121118
# -----------------------------
122119
# Test
123120
# -----------------------------
124-
X_struct_test = get_probs_from_mutations(m_test)
121+
#X_struct_test = get_probs_from_mutations(m_test)
125122
#X_seq_test = torch.tensor(extract_esm_embeddings(s_test)).float()
126123
X_seq_test = torch.tensor(extract_prosst_embeddings(prosst_base_model, prosst_tokenizer, s_test, wt_structure_input_ids))
127-
X_test_combined = torch.cat([X_seq_test, X_struct_test], dim=-1)
124+
X_test_combined = torch.cat([X_seq_test, X_seq_test], dim=-1)
128125

129126
model.eval()
130127
likelihood.eval()
@@ -143,15 +140,23 @@ def forward(self, X):
143140
rho, p = spearmanr(y_train, y_pred_train)
144141
print("Spearman rho SciPy TRAIN:", rho)
145142
print("Spearman soft TRAIN:", spearman_soft(y_train, torch.from_numpy(y_pred_train)).item())
143+
print("Correlation loss Spearman TRAIN:", correlation_loss(y_train, torch.from_numpy(y_pred_train), method="spearman"))
144+
print("Correlation hybrid MSE loss Spearman TRAIN:", hybrid_corr_mse_loss(y_train, torch.from_numpy(y_pred_train)))
145+
print("Correlation loss Pearson TRAIN:", correlation_loss(y_train, torch.from_numpy(y_pred_train), method="pearson"))
146+
print("Correlation loss Pearson 2 TRAIN:", pearson_loss(y_train, torch.from_numpy(y_pred_train)))
146147
y_train_t = y_train.float().unsqueeze(0) # shape (1, n)
147148
y_pred_train_t = torch.from_numpy(y_pred_train).float().unsqueeze(0) # shape (1, n)
148-
print("Spearman corr diff (ChatGPT) TRAIN:", spearman_corr_differentiable(y_train_t, y_pred_train_t).item())
149-
print("Spearman2 torchsort TRAIN:", spearmanr2(y_train_t, y_pred_train_t).item())
149+
#print("Spearman corr diff (ChatGPT) TRAIN:", spearman_corr_differentiable(y_train_t, y_pred_train_t).item())
150+
#print("Spearman2 torchsort TRAIN:", spearmanr2(y_train_t, y_pred_train_t).item())
150151

151152
rho, p = spearmanr(y_test, y_pred)
152153
print("Spearman rho SciPy TEST:", rho)
153154
print("Spearman soft TEST:", spearman_soft(y_test, torch.from_numpy(y_pred)).item())
155+
print("Correlation loss Spearman TEST:", correlation_loss(y_test, torch.from_numpy(y_pred), method="spearman"))
156+
print("Correlation hybrid MSE loss Spearman TEST:", hybrid_corr_mse_loss(y_test, torch.from_numpy(y_pred)))
157+
print("Correlation loss Pearson TEST:", correlation_loss(y_test, torch.from_numpy(y_pred), method="pearson"))
158+
print("Correlation loss Pearson 2 TEST:", pearson_loss(y_test, torch.from_numpy(y_pred)))
154159
y_test_t = y_test.float().unsqueeze(0) # shape (1, n)
155160
y_pred_t = torch.from_numpy(y_pred).float().unsqueeze(0) # shape (1, n)
156-
print("Spearman corr diff (ChatGPT) TEST:", spearman_corr_differentiable(y_test_t, y_pred_t).item())
157-
print("Spearman2 torchsort TEST:", spearmanr2(y_test_t, y_pred_t).item())
161+
#print("Spearman corr diff (ChatGPT) TEST:", spearman_corr_differentiable(y_test_t, y_pred_t).item())
162+
#print("Spearman2 torchsort TEST:", spearmanr2(y_test_t, y_pred_t).item())

pypef/gaussian_process/gp_opt.py

Whitespace-only changes.

pypef/gaussian_process/gp_prosst_test.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -139,14 +139,11 @@ def extract_prosst_embeddings(
139139

140140

141141
def gp_fit():
142-
pass
143-
142+
pass # TODO
144143

145144

146145
def gp_fit_sklearn():
147-
pass
148-
149-
146+
pass # TODO
150147

151148

152149
if __name__ == '__main__':

pypef/gaussian_process/metrics.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

pypef/plm/inference.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
from pypef.plm.prosst_lora_tune import get_prosst_models, get_structure_quantizied
1717
from pypef.utils.helpers import get_device
18-
from pypef.plm.utils import corr_loss, get_batches
18+
from pypef.plm.utils import pearson_loss, get_batches
1919
from pypef.plm.esm_lora_tune import get_esm_models
2020

2121

@@ -162,12 +162,12 @@ def sequence_log_likelihood(
162162
output_hidden_states=extract_emb,
163163
**model_kwargs
164164
)
165-
166-
token_embeddings = outputs.hidden_states[-1] # (1, L+2, D)
167-
# Mean pool over residues (exclude CLS/EOS)
168-
seq_embedding = token_embeddings[0, 1:-1].mean(dim=0)
169-
embeddings.append(seq_embedding)
170-
continue
165+
if extract_emb:
166+
token_embeddings = outputs.hidden_states[-1] # (1, L+2, D)
167+
# Mean pool over residues (exclude CLS/EOS)
168+
seq_embedding = token_embeddings[0, 1:-1].mean(dim=0)
169+
embeddings.append(seq_embedding)
170+
continue
171171

172172
except TypeError as e:
173173
logger.info(f"Did not find model input keyword arguments (kwargs: "
@@ -633,7 +633,7 @@ def esm_setup(wt_seq, sequences, device: str | None = None, verbose: bool = True
633633
'llm_optimizer': esm_optimizer,
634634
'llm_train_function': plm_train,
635635
'llm_inference_function': plm_inference,
636-
'llm_loss_function': corr_loss,
636+
'llm_loss_function': pearson_loss,
637637
'x_llm' : torch.tensor(x_esm), # TODO: Not needed here?
638638
'llm_attention_mask': torch.tensor(esm_attention_mask), # TODO: Not needed here?
639639
'wt_input_ids': torch.tensor(wt_tokens), # TODO: Not needed here?
@@ -680,7 +680,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None, verbose
680680
'llm_optimizer': prosst_optimizer,
681681
'llm_train_function': plm_train,
682682
'llm_inference_function': plm_inference, # prosst_infer,
683-
'llm_loss_function': corr_loss,
683+
'llm_loss_function': pearson_loss,
684684
'x_llm' : x_llm_train_prosst,
685685
'llm_attention_mask': prosst_attention_mask,
686686
'llm_vocab': prosst_vocab,

pypef/plm/utils.py

Lines changed: 100 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,70 @@
1313
logger = logging.getLogger('pypef.llm.utils')
1414

1515

16-
def corr_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
16+
def hybrid_corr_mse_loss(y_true, y_pred, tau=0.1, alpha=0.5):
17+
"""
18+
Hybrid differentiable loss combining Spearman correlation and MSE.
19+
"""
20+
# Differentiable Spearman
21+
loss_rank = correlation_loss(y_true, y_pred, method="spearman", tau=tau)
22+
# MSE
23+
loss_value = torch.mean((y_pred - y_true)**2)
24+
# Combine
25+
return alpha * loss_rank + (1 - alpha) * loss_value
26+
27+
28+
29+
def correlation_loss(y_true: torch.Tensor,
30+
y_pred: torch.Tensor,
31+
method: str = "spearman",
32+
tau: float = 0.1) -> torch.Tensor:
33+
"""
34+
Differentiable correlation loss for PyTorch.
35+
36+
Args:
37+
y_true: Tensor of shape (..., n) or (batch, n)
38+
y_pred: Tensor of same shape as y_true
39+
method: "pearson" or "spearman"
40+
tau: temperature for soft-rank approximation (used if method="spearman")
41+
42+
Returns:
43+
Scalar tensor representing the loss (to minimize)
44+
"""
45+
if method == "spearman":
46+
# Soft rank approximation
47+
x = y_true
48+
y = y_pred
49+
50+
def soft_rank(x, tau):
51+
x = x.unsqueeze(-1)
52+
diff = x - x.transpose(-1, -2)
53+
P = torch.sigmoid(diff / tau)
54+
return P.sum(dim=-1) + 0.5
55+
56+
rx = soft_rank(x, tau)
57+
ry = soft_rank(y, tau)
58+
elif method == "pearson":
59+
rx = y_true
60+
ry = y_pred
61+
else:
62+
raise ValueError(f"Unsupported method: {method}. Choose 'pearson' or 'spearman'.")
63+
64+
# Centering
65+
rx_c = rx - rx.mean(dim=-1, keepdim=True)
66+
ry_c = ry - ry.mean(dim=-1, keepdim=True)
67+
68+
# Normalize (like dividing by std)
69+
rx_n = rx_c / (rx_c.norm(dim=-1, keepdim=True) + 1e-8)
70+
ry_n = ry_c / (ry_c.norm(dim=-1, keepdim=True) + 1e-8)
71+
72+
# Compute correlation
73+
corr = (rx_n * ry_n).sum(dim=-1)
74+
75+
# Return scalar loss (to minimize, so negative correlation)
76+
return -corr.mean()
77+
78+
79+
def pearson_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
1780
res_true = y_true - torch.mean(y_true)
1881
res_pred = y_pred - torch.mean(y_pred)
1982
cov = torch.mean(res_true * res_pred)
@@ -24,6 +87,41 @@ def corr_loss(y_true: torch.Tensor, y_pred: torch.Tensor):
2487
return - cov / (sigma_true * sigma_pred)
2588

2689

90+
def spearman_loss(y_true, y_pred, tau=0.1):
91+
"""Maximizing Spearman correlation"""
92+
return - spearman_soft(y_true, y_pred, tau=tau).mean()
93+
94+
95+
def soft_rank_approx(x, tau=1.0):
96+
"""
97+
A simple soft rank approximation using pairwise comparisons.
98+
Args:
99+
x: tensor of shape (..., n)
100+
tau: temperature (larger = softer, smaller = closer to true ranks)
101+
Returns:
102+
approx ranks same shape as x
103+
"""
104+
diff = x.unsqueeze(-1) - x.unsqueeze(-2)
105+
# pairwise sigmoid scores
106+
P = torch.sigmoid(diff / tau)
107+
# sum of how many values each element is less than
108+
r = P.sum(dim=-1) + 0.5 # +0.5 to approximate average rank
109+
return r
110+
111+
112+
def spearman_soft(x, y, tau=0.1):
113+
rx = soft_rank_approx(x, tau)
114+
ry = soft_rank_approx(y, tau)
115+
116+
rxc = rx - rx.mean(dim=-1, keepdim=True)
117+
ryc = ry - ry.mean(dim=-1, keepdim=True)
118+
119+
rxn = rxc / (rxc.norm(dim=-1, keepdim=True) + 1e-8)
120+
ryn = ryc / (ryc.norm(dim=-1, keepdim=True) + 1e-8)
121+
122+
return (rxn * ryn).sum(dim=-1)
123+
124+
27125
def get_batches(a, dtype, batch_size=5,
28126
keep_remaining=False, verbose: bool = False
29127
) -> list | list[np.ndarray]:
@@ -80,6 +178,7 @@ def is_model_cached(repo_id: str, cache_dir: str):
80178
Check if the required model and tokenizer files are cached locally.
81179
"""
82180
snapshot_dir = None
181+
ref_file = None
83182
if os.path.isdir(cache_dir):
84183
ref_file = os.path.join(
85184
cache_dir, f'models--{repo_id.replace("/", "--")}', 'refs', 'main'

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
import sys # Use local directory PyPEF files
2121
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
2222
from pypef.dca.gremlin_inference import GREMLIN
23-
from pypef.plm.utils import get_batches, corr_loss
23+
from pypef.plm.utils import get_batches, pearson_loss
2424
from pypef.plm.esm_lora_tune import (
2525
get_esm_models, tokenize_sequences,
2626
esm_train, esm_infer
@@ -234,7 +234,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
234234
'llm_optimizer': esm_optimizer,
235235
'llm_train_function': esm_train,
236236
'llm_inference_function': esm_infer,
237-
'llm_loss_function': corr_loss,
237+
'llm_loss_function': pearson_loss,
238238
'x_llm' : x_llm_train_esm,
239239
'llm_attention_mask': esm_attention_mask
240240
}
@@ -246,7 +246,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
246246
'llm_optimizer': prosst_optimizer,
247247
'llm_train_function': prosst_train,
248248
'llm_inference_function': get_logits_from_full_seqs,
249-
'llm_loss_function': corr_loss,
249+
'llm_loss_function': pearson_loss,
250250
'x_llm' : x_llm_train_prosst,
251251
'llm_attention_mask': prosst_attention_mask,
252252
'input_ids': input_ids,

0 commit comments

Comments (0)