Commit 36a89d3 (committed to new branch: dev)

implementing/testing (hybrid) fine-tuned protein LLM models such as ESM1v

1 parent cd3d0a7 commit 36a89d3
File tree: 2 files changed (+1041 / -0 lines changed)
Lines changed: 357 additions & 0 deletions

@@ -0,0 +1,357 @@
# Niklas Siedhoff, 17.01.2025
# Inspired by ConFit
# https://github.com/luo-group/ConFit


import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy import stats
import pandas as pd
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig
from scipy.stats import spearmanr
import os
import gc
import matplotlib.pyplot as plt
from matplotlib.colors import XKCD_COLORS
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split


# Get cpu, gpu, or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
#device = "cpu"
# https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L91 -->
# https://github.com/facebookresearch/esm/blob/main/esm/constants.py#L7
#proteinseq_toks = {
#    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
#}
## self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
proteinseq_toks = {
    'toks': ['<null_0>', '<pad>', '<eos>', '<unk>',
             'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-',
             '<cls>', '<mask>', '<sep>']
}
#print(len(proteinseq_toks['toks']))

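# The commented-out tok_to_idx line above comes from the ESM alphabet; a small
# illustrative helper (not used below) mapping tokens to indices, which makes
# the hard-coded canonical range 4..23 used later easy to verify:
tok_to_idx = {tok: i for i, tok in enumerate(proteinseq_toks['toks'])}
# e.g. tok_to_idx['L'] == 4 and tok_to_idx['C'] == 23, i.e. the 20 canonical
# amino acids occupy indices 4..23 of this vocabulary.
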
"""
47+
48+
seqs_, attention_mask = tokenizer(
49+
seqs,
50+
padding='max_length',
51+
truncation=True,
52+
max_length=len(wt_seq)
53+
).values()
54+
print(seqs_)
55+
print(np.shape(seqs_), np.shape(attention_mask))
56+
print(muts)
57+
log_scores = []
58+
wt_score = 0.0
59+
with torch.no_grad():
60+
wt_sequence_distribution = []
61+
for i, (seq_, attn_msk) in enumerate(tqdm(zip(seqs_, attention_mask), total=len(seqs_))):
62+
out = model(torch.tensor([seq_]), torch.tensor([attn_msk]), output_hidden_states=True) # batch of 1
63+
logits = out.logits
64+
#print(i,'/',len(seqs_), logits.shape)
65+
log_probs = torch.log_softmax(logits, dim=-1)
66+
#print(log_probs, log_probs.shape)
67+
if i == 0:
68+
for i_pos, aa_pos in enumerate(log_probs[0]): # is of shape x,y,1 so one can simply use [0]
69+
#print(f'Seq. {i}, AA pos {i_pos+2}, log Probs.: {aa_pos}')
70+
canonical_positional_aa_distribution = []
71+
for i_aa_prob, aa_prob in enumerate(aa_pos):
72+
#print(f"AA {i_aa_prob} = {proteinseq_toks['toks'][i_aa_prob]} : {aa_prob}")
73+
if 4 <= i_aa_prob <= 23:
74+
canonical_positional_aa_distribution.append(aa_prob)
75+
wt_sequence_distribution.append(canonical_positional_aa_distribution)
76+
#print(np.shape(wt_sequence_distribution))
77+
if i == 0:
78+
wt_score = torch.sum(log_probs)
79+
else:
80+
log_scores.append(float(torch.sum(log_probs).cpu()))
81+
#print(torch.sum(log_probs))
82+
83+
print('WT score:', wt_score)
84+
df['predicted_log_score_ESM1v'] = log_scores
85+
df.to_csv('esm1_pred.csv')
86+
print(stats.spearmanr(scores, log_scores))
87+
print('np.shape(wt_sequence_distribution):',np.shape(wt_sequence_distribution))
88+
89+
var_y_trues = dict(zip(muts, scores))
90+
var_y_preds = dict(zip(muts, log_scores))
91+
92+
93+
fig, ax = plt.subplots(figsize=(30, 6))
94+
k = 0
95+
x_tick_poses, labels = [], []
96+
for i, aa_distr in enumerate(wt_sequence_distribution):
97+
x_tick_pos = np.array(range(len(aa_distr))) + k
98+
for aa in proteinseq_toks['toks'][4:24]:
99+
labels.append(f"{i+2}{aa}")
100+
if i == 0:
101+
plt.bar(x_tick_pos, aa_distr, label=proteinseq_toks['toks'][4:24], color=XKCD_COLORS)
102+
else:
103+
plt.bar(x_tick_pos, aa_distr, color=XKCD_COLORS)
104+
x_tick_poses.append(x_tick_pos)
105+
k += len(aa_distr) + 1
106+
plt.legend()
107+
plt.xticks(np.array(x_tick_poses).flatten(), labels, size=1, rotation=45)
108+
plt.margins(0.01)
109+
plt.tight_layout()
110+
plt.savefig('aa_esm1v_probability_distribution.png', dpi=300)
111+
112+
113+
#yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs = sort_var_fits(var_y_trues, var_y_preds)
114+
#get_loss(yp_fs_sorted_according_to_ytfs)
115+
"""
116+
117+
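# For reference: the block above scores full sequences by summing per-token
# log-probabilities. The ESM-1v paper instead recommends masked-marginal
# scoring of point mutants; a minimal sketch under the conventions used here
# (mut like 'A25G' with a 1-based position into wt_seq; +1 token offset for
# the leading special token; not called anywhere in this file):
def masked_marginal_score(wt_seq, mut, model, tokenizer):
    pos = int(mut[1:-1])            # e.g. 'A25G' -> 25
    wt_aa, mut_aa = mut[0], mut[-1]
    enc = tokenizer(wt_seq, return_tensors='pt')
    input_ids = enc['input_ids'].clone()
    input_ids[0, pos] = tokenizer.mask_token_id  # mask the mutated position
    with torch.no_grad():
        logits = model(input_ids.to(device)).logits
    log_probs = torch.log_softmax(logits[0, pos], dim=-1)
    # score = log p(mutant aa) - log p(wild-type aa) at the masked position
    return (log_probs[tokenizer.convert_tokens_to_ids(mut_aa)]
            - log_probs[tokenizer.convert_tokens_to_ids(wt_aa)]).item()
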
def sort_var_fits(var_fits_true, var_fits_pred):

    def get_kvs(d):
        variants, fitnesses = [], []
        for k, v in d.items():
            variants.append(k)
            fitnesses.append(v)
        return variants, fitnesses

    yt_vs, yt_fs = get_kvs(var_fits_true)
    yp_vs, yp_fs = get_kvs(var_fits_pred)
    assert len(yt_vs) == len(yt_fs) == len(yp_vs) == len(yp_fs)
    if not yt_vs == yp_vs:
        yp_vs_temp, yp_fs_temp = [], []
        for vart, _yt in zip(yt_vs, yt_fs):
            for varp, yp in zip(yp_vs, yp_fs):
                if vart == varp:
                    yp_vs_temp.append(vart)
                    yp_fs_temp.append(yp)
        yp_vs, yp_fs = yp_vs_temp, yp_fs_temp
        assert yt_vs == yp_vs
    (
        yt_vs_sorted, yt_fs_sorted,
        yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs
    ) = [list(l) for l in zip(*sorted(zip(yt_vs, yt_fs, yp_vs, yp_fs), key=lambda x: x[1]))]

    assert yt_vs_sorted == yp_vs_sorted_according_to_ytfs

    return yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs

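# Worked example: with var_fits_true = {'A1G': 0.9, 'T2S': 0.1} and
# var_fits_pred = {'T2S': -3.0, 'A1G': -1.0}, the predictions are first
# re-aligned to the true-variant order, then all four lists are sorted by
# ascending true fitness:
# -> (['T2S', 'A1G'], [0.1, 0.9], ['T2S', 'A1G'], [-3.0, -1.0])
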
def get_encoded_seqs(sequences, tokenizer, max_length=104):
    encoded_sequences, attention_masks = tokenizer(
        sequences,
        padding='max_length',
        truncation=True,
        max_length=max_length
    ).values()
    return encoded_sequences, attention_masks

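# Note: the .values() unpacking above relies on the tokenizer's BatchEncoding
# preserving key order ('input_ids', then 'attention_mask'). An explicit-key
# variant with the same behavior (illustrative, not used below):
def get_encoded_seqs_explicit(sequences, tokenizer, max_length=104):
    enc = tokenizer(sequences, padding='max_length', truncation=True, max_length=max_length)
    return enc['input_ids'], enc['attention_mask']
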
def get_y_pred_scores(encoded_sequences, attention_masks, model):
    out = model(encoded_sequences.to(device), attention_masks.to(device), output_hidden_states=True)
    logits = out.logits
    token_probs = torch.log_softmax(logits, dim=-1)
    log_probs = []
    for i_s, sequence in enumerate(encoded_sequences):
        seq_log_probs = []
        for i_aa, aa in enumerate(sequence):
            #print('Target AA:', i_aa, aa, proteinseq_toks['toks'][aa], token_probs[i_s, i_aa, aa])
            if i_aa == 0:
                seq_log_probs = token_probs[i_s, i_aa, aa].reshape(1)  # or better just use Tensor.index_select()!
                #print(seq_log_probs)
            else:
                seq_log_probs = torch.cat((seq_log_probs, token_probs[i_s, i_aa, aa].reshape(1)), 0)
                #seq_log_probs.append(token_probs[i_s, i_aa, aa])
        # Note: seq_log_probs is already a tensor here; re-wrapping it in
        # torch.Tensor(...) would detach it from the autograd graph and break
        # loss.backward() during training.
        if i_s == 0:
            log_probs = torch.sum(seq_log_probs).reshape(1)
        else:
            log_probs = torch.cat((log_probs, torch.sum(seq_log_probs).reshape(1)), 0)
    return log_probs

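def get_y_pred_scores_vectorized(encoded_sequences, attention_masks, model):
    # Batched sketch of get_y_pred_scores() above, following the
    # "just use index_select/gather" note: picks each token's own
    # log-probability with a single torch.gather call. Assumes int64 inputs
    # of shape (batch, seq_len); unlike the loop above, padded positions are
    # zeroed out via the attention mask. Not wired into the pipeline below.
    ids = encoded_sequences.to(device)
    mask = attention_masks.to(device)
    token_log_probs = torch.log_softmax(model(ids, mask).logits, dim=-1)  # (B, L, V)
    picked = token_log_probs.gather(2, ids.unsqueeze(-1)).squeeze(-1)     # (B, L)
    return (picked * mask).sum(dim=-1)                                    # (B,)
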
def _get_ranks(x: torch.Tensor) -> torch.Tensor:
    tmp = x.argsort()
    ranks = torch.zeros_like(tmp)
    ranks[tmp] = torch.arange(len(x))
    return ranks


def spearman_correlation(x: torch.Tensor, y: torch.Tensor):
    """Compute Spearman correlation between two 1-D vectors.

    Args:
        x: Shape (N,)
        y: Shape (N,)
    """
    x_rank = _get_ranks(x)
    y_rank = _get_ranks(y)

    n = x.size(0)
    upper = 6 * torch.sum((x_rank - y_rank).pow(2))
    down = n * (n ** 2 - 1.0)
    return 1.0 - (upper / down)


# yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs = sort_var_fits(var_y_trues, var_y_preds)

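# The closed-form expression above, 1 - 6*sum(d^2) / (n*(n^2 - 1)), is exact
# only for tie-free data; with ties, scipy's spearmanr is the safer reference.
# Illustrative sanity check (not executed here):
#   x = torch.randn(100); y = x + 0.1 * torch.randn(100)
#   print(spearman_correlation(x, y), stats.spearmanr(x.numpy(), y.numpy())[0])
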
# TODO:
##############################################################################
## Adapt ESM1v model LoRA parameters based on BT (Bradley-Terry) loss to
## low-N observations (fitness values)


def corr_loss(y_true, y_pred):
    res_true = y_true - torch.mean(y_true)
    res_pred = y_pred - torch.mean(y_pred)
    #res_true = res_true.to(device)
    #res_pred = res_pred.to(device)
    cov = torch.mean(res_true * res_pred)
    var_true = torch.mean(res_true**2)
    var_pred = torch.mean(res_pred**2)
    sigma_true = torch.sqrt(var_true)
    sigma_pred = torch.sqrt(var_pred)
    # negative Pearson correlation: minimizing this maximizes correlation
    return -cov / (sigma_true * sigma_pred)


#print(encoded_seqs)
#y_preds = get_y_pred_scores(encoded_seqs, attention_masks)

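# The TODO above targets a Bradley-Terry (BT) ranking loss as in ConFit; a
# minimal pairwise sketch (assumed formulation, not ConFit's exact loss) that
# could be swapped in for corr_loss via `loss_fn = bt_loss`:
def bt_loss(y_true, y_pred):
    # For every pair (i, j) with y_true[j] > y_true[i], reward the model for
    # also predicting y_pred[j] > y_pred[i].
    diff_pred = y_pred.unsqueeze(0) - y_pred.unsqueeze(1)  # [i, j] = y_pred[j] - y_pred[i]
    diff_true = y_true.unsqueeze(0) - y_true.unsqueeze(1)  # [i, j] = y_true[j] - y_true[i]
    mask = (diff_true > 0).float()                         # pairs ordered by true fitness
    return -(torch.nn.functional.logsigmoid(diff_pred) * mask).sum() / mask.sum().clamp(min=1.0)
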
def get_batches(a, batch_size=5):
    a = np.array(a)
    orig_shape = np.shape(a)
    remaining = len(a) % batch_size
    if remaining != 0:
        a = a[:-remaining]
    if len(orig_shape) == 2:
        a = a.reshape(np.shape(a)[0] // batch_size, batch_size, np.shape(a)[1])
    else:  # elif len(orig_shape) == 1:
        a = a.reshape(np.shape(a)[0] // batch_size, batch_size)
    new_shape = np.shape(a)
    print(f'{orig_shape} -> {new_shape} (dropped {remaining})')
    return torch.Tensor(a).to(device)

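# Dataset/DataLoader are imported above but unused; a hypothetical
# DataLoader-based alternative to get_batches() (keeps the tail batch instead
# of dropping it; not wired into the pipeline below):
def make_loader(encoded, masks, targets, batch_size=5):
    from torch.utils.data import TensorDataset
    ds = TensorDataset(torch.tensor(encoded, dtype=torch.int64),
                       torch.tensor(masks, dtype=torch.int64),
                       torch.tensor(targets, dtype=torch.float32))
    return DataLoader(ds, batch_size=batch_size, shuffle=False)
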
def test(xs, attns, scores, loss_fn, model):
    print('TESTING...')
    for i, (xs_b, attns_b) in enumerate(tqdm(zip(xs, attns), total=len(xs))):
        xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
        with torch.no_grad():
            y_preds = get_y_pred_scores(xs_b, attns_b, model)
        if i == 0:
            y_preds_total = y_preds
        else:
            y_preds_total = torch.cat((y_preds_total, y_preds))
    # compute the loss once over all accumulated batches
    loss = loss_fn(
        torch.flatten(scores),
        torch.flatten(y_preds_total)
    )
    print(f'TESTING LOSS: {float(loss.cpu()):.3f}')
    return torch.flatten(scores), torch.flatten(y_preds_total)

def infer(xs, attns, model):
    for i, (xs_b, attns_b) in enumerate(tqdm(zip(xs, attns), total=len(xs))):
        xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
        with torch.no_grad():
            y_preds = get_y_pred_scores(xs_b, attns_b, model)
        if i == 0:
            y_preds_total = y_preds
        else:
            y_preds_total = torch.cat((y_preds_total, y_preds))
    return torch.flatten(y_preds_total)

def train(xs, attns, scores, loss_fn, model, optimizer, n_epochs=3):
    for epoch in range(n_epochs):
        model.train()
        pbar = tqdm(zip(xs, attns, scores), total=len(xs))
        for batch, (xs_b, attns_b, scores_b) in enumerate(pbar):
            xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
            y_preds = get_y_pred_scores(xs_b, attns_b, model)
            scores_b = scores_b.to(device)
            loss = loss_fn(scores_b, y_preds)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            #saved_params = []
            #for i, (name, param) in enumerate(model.named_parameters()):  # 33 layers (0-32)
            #    if 'lora' in name:
            #        saved_params.append(torch.sum(param).clone())
            pbar.set_description(
                f"EPOCH: {epoch + 1}. Loss: {loss.item():>1f} [batch: {batch+1}/{len(xs)}: {(batch + 1) * len(xs_b):>5d}/{len(xs)*len(xs_b)}] "
                #f"(LoRA weight sum:{sum(saved_params):.3f})"
            )

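# PeftModel/PeftConfig are imported above but not yet used; a minimal sketch
# for persisting and restoring the trained LoRA adapter after train()
# (directory name hypothetical):
#
#   model.save_pretrained('esm1v_lora_adapter')
#   base = EsmForMaskedLM.from_pretrained('facebook/esm1v_t33_650M_UR90S_3')
#   model = PeftModel.from_pretrained(base, 'esm1v_lora_adapter')
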
def plot_true_preds(y_true, y_pred, muts):
    fig, ax = plt.subplots()
    plt.scatter(y_true, y_pred)
    for yt, yp, m in zip(y_true, y_pred, muts):
        if yt >= 1.0 * max(y_true):  # only label the top-fitness variant(s)
            plt.text(yt, yp, m)
    plt.text(max(y_true), max(y_pred), f'{spearmanr(y_true, y_pred)[0]:.3f}')
    plt.show()
    plt.clf()

if __name__ == '__main__':
    df = pd.read_csv('CBPA2_HUMAN_Tsuboyama_2023_1O6X.csv', sep=',')
    print(df)
    muts_, seqs, scores = df['mutant'], df['mutated_sequence'], df['DMS_score']
    wt_seq = "VGDQVLEIVPSNEEQIKNLLQLEAQEHLQLDFWKSPTTPGETAHVRVPFVNVQAVKVFLESQGIAYSIMIED"
    seqs = seqs.to_list()  #+ [wt_seq]
    scores_ = scores.to_list()  #+ [1.0]
    print(len(seqs), len(scores))

    print(f"Using {device} device")

    basemodel = EsmForMaskedLM.from_pretrained('facebook/esm1v_t33_650M_UR90S_3')
    model_reg = EsmForMaskedLM.from_pretrained('facebook/esm1v_t33_650M_UR90S_3')
    tokenizer = EsmTokenizer.from_pretrained('facebook/esm1v_t33_650M_UR90S_3')

    peft_config = LoraConfig(r=8, target_modules=["query", "value"])
    model = get_peft_model(basemodel, peft_config)
    model = model.to(device)

    encoded_seqs_, attention_masks_ = get_encoded_seqs(seqs, tokenizer, max_length=len(wt_seq))
    #encoded_seqs, attention_masks, scores = encoded_seqs_[200:350], attention_masks_[200:350], scores_[200:350]
    #encoded_seqs_test, attention_masks_test, scores_test = encoded_seqs_[350:400], attention_masks_[350:400], scores_[350:400]
    encoded_seqs, encoded_seqs_test, attention_masks, attention_masks_test, scores, scores_test, muts, muts_test = train_test_split(
        encoded_seqs_, attention_masks_, scores_, muts_, train_size=0.2, shuffle=True, random_state=42)
    print('\n' + '-' * 60 + f'\nTrain size: {len(encoded_seqs)}, test size: {len(encoded_seqs_test)}')
    xs, attns, scores = get_batches(encoded_seqs), get_batches(attention_masks), get_batches(scores)
    xs_test, attns_test, scores_test = get_batches(encoded_seqs_test), get_batches(attention_masks_test), get_batches(scores_test)

    print('SHAPES:', np.shape(xs), np.shape(attns), np.shape(scores))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    #loss_fn = torch.nn.MSELoss()
    loss_fn = corr_loss


    print('\n\nPRE-TRAIN-PERFORMANCE')
    # INITIAL TEST
    y_true, y_pred = test(xs=xs, attns=attns, scores=scores, loss_fn=loss_fn, model=model)
    y_true_test, y_pred_test = test(xs_test, attns_test, scores_test, loss_fn, model)
    plot_true_preds(y_true_test.cpu(), y_pred_test.cpu(), muts_test)


    print('\n\nRE-TRAINING ESM1v...')
    # TRAIN
    # https://stackoverflow.com/questions/56360644/pytorch-runtimeerror-expected-tensor-for-argument-1-indices-to-have-scalar-t
    train(xs, attns, scores, loss_fn, model, optimizer, n_epochs=3)


    # TEST
    print('\nPOST-TRAIN-PERFORMANCE')
    y_true, y_pred = test(xs=xs, attns=attns, scores=scores, loss_fn=loss_fn, model=model)
    y_true_test, y_pred_test = test(xs_test, attns_test, scores_test, loss_fn, model)
    plot_true_preds(y_true_test.cpu(), y_pred_test.cpu(), muts_test)
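    # infer() is defined above but never called; hypothetical follow-up usage
    # for scoring new (unlabeled) variants with the tuned model:
    #   y_pred_new = infer(xs_test, attns_test, model)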
