|
| 1 | +# Niklas Siedhoff, 17.01.2025 |
| 2 | +# Inspired from ConFit |
| 3 | +# https://github.com/luo-group/ConFit |
| 4 | + |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch.utils.data import Dataset, DataLoader |
| 8 | +import numpy as np |
| 9 | +from scipy import stats |
| 10 | +import pandas as pd |
| 11 | +from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model |
| 12 | +from peft.utils.other import fsdp_auto_wrap_policy |
| 13 | +from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig |
| 14 | +from scipy.stats import spearmanr |
| 15 | +import os |
| 16 | +import gc |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +from matplotlib.colors import XKCD_COLORS |
| 19 | +from tqdm import tqdm |
| 20 | +from sklearn.metrics import mean_squared_error |
| 21 | +from sklearn.model_selection import train_test_split |
| 22 | + |
| 23 | + |
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
#device="cpu"
# Token vocabulary reproduced from the ESM codebase (see links below):
# https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L91 -->
# https://github.com/facebookresearch/esm/blob/main/esm/constants.py#L7
#proteinseq_toks = {
#    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
#}
## self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
# List index == ESM token id: indices 0-3 are special tokens, indices 4-23
# are the 20 canonical amino acids (hence the `[4:24]` slices / `4 <= i <= 23`
# checks used elsewhere in this file), 24-30 are ambiguity/non-canonical
# codes, and the tail holds the remaining special tokens.
proteinseq_toks = {
    'toks': ['<null_0>', '<pad>', '<eos>', '<unk>',
             'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-',
             '<cls>', '<mask>', '<sep>']
}
#print(len(proteinseq_toks['toks']))
| 45 | + |
| 46 | +""" |
| 47 | +
|
| 48 | +seqs_, attention_mask = tokenizer( |
| 49 | + seqs, |
| 50 | + padding='max_length', |
| 51 | + truncation=True, |
| 52 | + max_length=len(wt_seq) |
| 53 | + ).values() |
| 54 | +print(seqs_) |
| 55 | +print(np.shape(seqs_), np.shape(attention_mask)) |
| 56 | +print(muts) |
| 57 | +log_scores = [] |
| 58 | +wt_score = 0.0 |
| 59 | +with torch.no_grad(): |
| 60 | + wt_sequence_distribution = [] |
| 61 | + for i, (seq_, attn_msk) in enumerate(tqdm(zip(seqs_, attention_mask), total=len(seqs_))): |
| 62 | + out = model(torch.tensor([seq_]), torch.tensor([attn_msk]), output_hidden_states=True) # batch of 1 |
| 63 | + logits = out.logits |
| 64 | + #print(i,'/',len(seqs_), logits.shape) |
| 65 | + log_probs = torch.log_softmax(logits, dim=-1) |
| 66 | + #print(log_probs, log_probs.shape) |
| 67 | + if i == 0: |
| 68 | + for i_pos, aa_pos in enumerate(log_probs[0]): # is of shape x,y,1 so one can simply use [0] |
| 69 | + #print(f'Seq. {i}, AA pos {i_pos+2}, log Probs.: {aa_pos}') |
| 70 | + canonical_positional_aa_distribution = [] |
| 71 | + for i_aa_prob, aa_prob in enumerate(aa_pos): |
| 72 | + #print(f"AA {i_aa_prob} = {proteinseq_toks['toks'][i_aa_prob]} : {aa_prob}") |
| 73 | + if 4 <= i_aa_prob <= 23: |
| 74 | + canonical_positional_aa_distribution.append(aa_prob) |
| 75 | + wt_sequence_distribution.append(canonical_positional_aa_distribution) |
| 76 | + #print(np.shape(wt_sequence_distribution)) |
| 77 | + if i == 0: |
| 78 | + wt_score = torch.sum(log_probs) |
| 79 | + else: |
| 80 | + log_scores.append(float(torch.sum(log_probs).cpu())) |
| 81 | + #print(torch.sum(log_probs)) |
| 82 | +
|
| 83 | +print('WT score:', wt_score) |
| 84 | +df['predicted_log_score_ESM1v'] = log_scores |
| 85 | +df.to_csv('esm1_pred.csv') |
| 86 | +print(stats.spearmanr(scores, log_scores)) |
| 87 | +print('np.shape(wt_sequence_distribution):',np.shape(wt_sequence_distribution)) |
| 88 | +
|
| 89 | +var_y_trues = dict(zip(muts, scores)) |
| 90 | +var_y_preds = dict(zip(muts, log_scores)) |
| 91 | +
|
| 92 | +
|
| 93 | +fig, ax = plt.subplots(figsize=(30, 6)) |
| 94 | +k = 0 |
| 95 | +x_tick_poses, labels = [], [] |
| 96 | +for i, aa_distr in enumerate(wt_sequence_distribution): |
| 97 | + x_tick_pos = np.array(range(len(aa_distr))) + k |
| 98 | + for aa in proteinseq_toks['toks'][4:24]: |
| 99 | + labels.append(f"{i+2}{aa}") |
| 100 | + if i == 0: |
| 101 | + plt.bar(x_tick_pos, aa_distr, label=proteinseq_toks['toks'][4:24], color=XKCD_COLORS) |
| 102 | + else: |
| 103 | + plt.bar(x_tick_pos, aa_distr, color=XKCD_COLORS) |
| 104 | + x_tick_poses.append(x_tick_pos) |
| 105 | + k += len(aa_distr) + 1 |
| 106 | +plt.legend() |
| 107 | +plt.xticks(np.array(x_tick_poses).flatten(), labels, size=1, rotation=45) |
| 108 | +plt.margins(0.01) |
| 109 | +plt.tight_layout() |
| 110 | +plt.savefig('aa_esm1v_probability_distribution.png', dpi=300) |
| 111 | +
|
| 112 | +
|
| 113 | +#yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs = sort_var_fits(var_y_trues, var_y_preds) |
| 114 | +#get_loss(yp_fs_sorted_according_to_ytfs) |
| 115 | +""" |
| 116 | + |
| 117 | + |
def sort_var_fits(var_fits_true, var_fits_pred):
    """Align and co-sort two variant->fitness mappings by true fitness.

    Parameters
    ----------
    var_fits_true : dict
        Mapping of variant name -> experimentally measured fitness.
    var_fits_pred : dict
        Mapping of variant name -> predicted fitness. Must contain every
        variant present in ``var_fits_true``.

    Returns
    -------
    tuple[list, list, list, list]
        ``(true_variants, true_fitnesses, pred_variants, pred_fitnesses)``,
        all ordered by ascending true fitness; the two variant lists are
        identical by construction.

    Raises
    ------
    KeyError
        If a variant of ``var_fits_true`` is missing from ``var_fits_pred``
        (the original implementation failed an assertion in that case).
    """
    yt_vs = list(var_fits_true)
    yt_fs = [var_fits_true[v] for v in yt_vs]
    # Align predictions to the true-variant order with O(1) dict lookups;
    # this replaces the original O(n^2) nested list scan.
    yp_vs = list(yt_vs)
    yp_fs = [var_fits_pred[v] for v in yt_vs]
    if not yt_vs:
        # Nothing to sort; the original crashed on empty input.
        return [], [], [], []
    (
        yt_vs_sorted, yt_fs_sorted,
        yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs
    ) = (list(col) for col in zip(*sorted(zip(yt_vs, yt_fs, yp_vs, yp_fs), key=lambda row: row[1])))
    assert yt_vs_sorted == yp_vs_sorted_according_to_ytfs
    return yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs
| 147 | + |
| 148 | + |
def get_encoded_seqs(sequences, tokenizer, max_length=104):
    """Tokenize *sequences* padded/truncated to *max_length*.

    BUGFIX: index the tokenizer output by key instead of unpacking
    ``.values()``, which silently returns the wrong fields if the tokenizer
    ever emits extra entries (e.g. ``token_type_ids``) or reorders them.

    Parameters
    ----------
    sequences : list[str]
        Protein sequences to encode.
    tokenizer : callable
        HF-style tokenizer returning a mapping with ``input_ids`` and
        ``attention_mask``.
    max_length : int
        Pad/truncate target length.

    Returns
    -------
    tuple
        ``(input_ids, attention_masks)`` exactly as produced by the tokenizer.
    """
    encoded = tokenizer(
        sequences,
        padding='max_length',
        truncation=True,
        max_length=max_length
    )
    return encoded['input_ids'], encoded['attention_mask']
| 157 | + |
def get_y_pred_scores(encoded_sequences, attention_masks, model):
    """Score each sequence as the summed log-probability of its own tokens.

    Runs the masked LM once over the batch and, for every sequence, sums the
    model's log-probability of the observed token at each position. Padding
    positions are included in the sum, matching the original implementation
    (the attention mask is only forwarded to the model).

    Parameters
    ----------
    encoded_sequences : torch.Tensor
        Integer token ids, shape (batch, seq_len).
    attention_masks : torch.Tensor
        Attention masks, shape (batch, seq_len).
    model : callable
        HF-style masked LM whose output exposes ``.logits``.

    Returns
    -------
    torch.Tensor
        Shape (batch,) summed per-token log-probabilities.
    """
    ids = encoded_sequences.to(device)
    out = model(ids, attention_masks.to(device), output_hidden_states=True)
    token_probs = torch.log_softmax(out.logits, dim=-1)
    # Vectorized replacement for the original per-token torch.cat loop
    # (quadratic in sequence length). NOTE(review): the original also wrapped
    # intermediates in torch.Tensor(...), which copies data and appears to
    # detach them from autograd; gather keeps the graph intact so training
    # losses can backpropagate — confirm against training behavior.
    per_token_log_probs = token_probs.gather(-1, ids.long().unsqueeze(-1)).squeeze(-1)
    return per_token_log_probs.sum(dim=-1)
| 181 | + |
| 182 | + |
| 183 | +def _get_ranks(x: torch.Tensor) -> torch.Tensor: |
| 184 | + tmp = x.argsort() |
| 185 | + ranks = torch.zeros_like(tmp) |
| 186 | + ranks[tmp] = torch.arange(len(x)) |
| 187 | + return ranks |
| 188 | + |
def spearman_correlation(x: torch.Tensor, y: torch.Tensor):
    """Compute the Spearman rank correlation of two 1-D tensors.

    Uses the classic ``1 - 6*sum(d^2) / (n*(n^2-1))`` formula on the rank
    differences (no tie correction).

    Args:
        x: Shape (N, )
        y: Shape (N, )
    """
    rank_diff = _get_ranks(x) - _get_ranks(y)
    n = x.size(0)
    return 1.0 - (6 * torch.sum(rank_diff.pow(2))) / (n * (n ** 2 - 1.0))
| 202 | + |
| 203 | + |
| 204 | +# yt_vs_sorted, yt_fs_sorted, yp_vs_sorted_according_to_ytfs, yp_fs_sorted_according_to_ytfs = sort_var_fits(var_y_trues, var_y_preds) |
| 205 | + |
| 206 | +# TODO: |
| 207 | +############################################################################## |
| 208 | +## Adapt ESM1v model LoRA parameters based on BT loss to low N observations (fitness values) |
| 209 | + |
| 210 | + |
def corr_loss(y_true, y_pred):
    """Negative Pearson correlation between targets and predictions.

    Minimizing this loss drives predictions toward perfect linear
    correlation with the targets; the value lies in [-1, 1].
    """
    centered_true = y_true - y_true.mean()
    centered_pred = y_pred - y_pred.mean()
    covariance = (centered_true * centered_pred).mean()
    std_true = centered_true.pow(2).mean().sqrt()
    std_pred = centered_pred.pow(2).mean().sqrt()
    return -covariance / (std_true * std_pred)
| 222 | + |
| 223 | + |
| 224 | +#print(encoded_seqs) |
| 225 | +#y_preds = get_y_pred_scores(encoded_seqs, attention_masks) |
| 226 | + |
def get_batches(a, batch_size=5):
    """Reshape *a* into batches of *batch_size* on the global `device`.

    Trailing samples that do not fill a complete batch are dropped; the
    original and resulting shapes (and drop count) are printed. The result
    is a float Tensor, so callers must cast token ids back to int64.
    """
    arr = np.array(a)
    orig_shape = np.shape(arr)
    remaining = len(arr) % batch_size
    if remaining != 0:
        arr = arr[:-remaining]
    n_batches = len(arr) // batch_size
    if len(orig_shape) == 2:
        arr = arr.reshape(n_batches, batch_size, np.shape(arr)[1])
    else:
        arr = arr.reshape(n_batches, batch_size)
    new_shape = np.shape(arr)
    print(f'{orig_shape} -> {new_shape} (dropped {remaining})')
    return torch.Tensor(arr).to(device)
| 240 | + |
| 241 | + |
def test(xs, attns, scores, loss_fn, model):
    """Evaluate *model* on pre-batched data without gradients.

    Concatenates per-batch predictions, prints the aggregate loss over the
    whole set, and returns flattened ``(y_true, y_pred)`` tensors.
    """
    print('TESTING...')
    collected = None
    for x_batch, attn_batch in tqdm(zip(xs, attns), total=len(xs)):
        # Embedding lookups require int64 ids (batches are stored as floats).
        x_batch = x_batch.to(torch.int64)
        attn_batch = attn_batch.to(torch.int64)
        with torch.no_grad():
            batch_preds = get_y_pred_scores(x_batch, attn_batch, model)
        collected = batch_preds if collected is None else torch.cat((collected, batch_preds))
    loss = loss_fn(
        torch.flatten(scores),
        torch.flatten(collected)
    )
    print(f'TESTING LOSS: {float(loss.cpu()):.3f}')
    return torch.flatten(scores), torch.flatten(collected)
| 258 | + |
| 259 | + |
def infer(xs, attns, model):
    """Predict scores for pre-batched inputs (no gradients, no loss).

    Returns the flattened concatenation of per-batch predictions.
    """
    collected = None
    for x_batch, attn_batch in tqdm(zip(xs, attns), total=len(xs)):
        # Cast float batches back to int64 token ids for the embedding layer.
        x_batch = x_batch.to(torch.int64)
        attn_batch = attn_batch.to(torch.int64)
        with torch.no_grad():
            batch_preds = get_y_pred_scores(x_batch, attn_batch, model)
        collected = batch_preds if collected is None else torch.cat((collected, batch_preds))
    return torch.flatten(collected)
| 270 | + |
| 271 | + |
def train(xs, attns, scores, loss_fn, model, optimizer, n_epochs=3):
    """Fine-tune *model* on pre-batched data for *n_epochs* epochs.

    Args:
        xs: Batched token-id tensors (stored as floats; cast below).
        attns: Batched attention-mask tensors, same batching as ``xs``.
        scores: Batched fitness targets.
        loss_fn: Callable ``loss_fn(y_true, y_pred)`` returning a scalar loss.
        model: Model whose trainable parameters (e.g. LoRA adapters) are updated.
        optimizer: Optimizer over ``model``'s trainable parameters.
        n_epochs: Number of passes over the data.
    """
    for epoch in range(n_epochs):
        model.train()
        pbar = tqdm(zip(xs, attns, scores), total=len(xs))
        for batch, (xs_b, attns_b, scores_b) in enumerate(pbar):
            # Embedding lookups require int64 ids (batches are float tensors).
            xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
            y_preds = get_y_pred_scores(xs_b, attns_b, model)
            scores_b = scores_b.to(device)
            loss = loss_fn(scores_b, y_preds)
            # Standard step: backprop, update, then clear gradients.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            #saved_params = []
            #for i, (name, param) in enumerate(model.named_parameters()): # 33 layers (0-32)
            #    if 'lora' in name:
            #        saved_params.append(torch.sum(param).clone())
            pbar.set_description(
                f"EPOCH: {epoch + 1}. Loss: {loss.item():>1f} [batch: {batch+1}/{len(xs)}: {(batch + 1) * len(xs_b):>5d}/{len(xs)*len(xs_b)}] "
                #f"(LoRA weight sum:{sum(saved_params):.3f})"
            )
| 292 | + |
| 293 | + |
def plot_true_preds(y_true, y_pred, muts):
    """Scatter predictions vs. true values, labelling the top variant(s).

    Mutant names are drawn only for points whose true value reaches the
    maximum; the Spearman rho is annotated at the plot's upper corner.
    """
    plt.subplots()
    plt.scatter(y_true, y_pred)
    threshold = 1.0 * max(y_true)  # hoisted out of the loop (invariant)
    for true_val, pred_val, mutant in zip(y_true, y_pred, muts):
        if true_val >= threshold:
            plt.text(true_val, pred_val, mutant)
    plt.text(max(y_true), max(y_pred), f'{spearmanr(y_true, y_pred)[0]:.3f}')
    plt.show()
    plt.clf()
| 303 | + |
| 304 | + |
if __name__ == '__main__':
    # Load the deep mutational scanning dataset: variant names, mutated
    # sequences, and measured fitness scores.
    df = pd.read_csv('CBPA2_HUMAN_Tsuboyama_2023_1O6X.csv', sep=',')
    print(df)
    muts_, seqs, scores = df['mutant'], df['mutated_sequence'], df['DMS_score']
    wt_seq = "VGDQVLEIVPSNEEQIKNLLQLEAQEHLQLDFWKSPTTPGETAHVRVPFVNVQAVKVFLESQGIAYSIMIED"
    seqs = seqs.to_list()  #+ [wt_seq]
    scores_ = scores.to_list()  #+ [1.0]
    print(len(seqs), len(scores))

    print(f"Using {device} device")

    model_name = 'facebook/esm1v_t33_650M_UR90S_3'
    # FIX: the original additionally loaded a second, never-used copy of this
    # 650M-parameter model ('model_reg'), doubling memory usage; removed.
    basemodel = EsmForMaskedLM.from_pretrained(model_name)
    tokenizer = EsmTokenizer.from_pretrained(model_name)

    # Attach LoRA adapters to the attention query/value projections so only
    # the low-rank adapter weights are trained.
    peft_config = LoraConfig(r=8, target_modules=["query", "value"])
    model = get_peft_model(basemodel, peft_config)
    model = model.to(device)

    encoded_seqs_, attention_masks_ = get_encoded_seqs(seqs, tokenizer, max_length=len(wt_seq))
    # Low-N setting: deliberately train on only 20% of the data and hold out
    # the remaining 80% for testing.
    encoded_seqs, encoded_seqs_test, attention_masks, attention_masks_test, scores, scores_test, muts, muts_test = train_test_split(
        encoded_seqs_, attention_masks_, scores_, muts_, train_size=0.2, shuffle=True, random_state=42)
    print('\n' + '-' * 60 + f'\nTrain size: {len(encoded_seqs)}, test size: {len(encoded_seqs_test)}')
    xs, attns, scores = get_batches(encoded_seqs), get_batches(attention_masks), get_batches(scores)
    xs_test, attns_test, scores_test = get_batches(encoded_seqs_test), get_batches(attention_masks_test), get_batches(scores_test)

    print('SHAPES:', np.shape(xs), np.shape(attns), np.shape(scores))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    # Correlation-based loss (negative Pearson); MSE kept as an alternative.
    #loss_fn = torch.nn.MSELoss()
    loss_fn = corr_loss

    print('\n\nPRE-TRAIN-PERFORMANCE')
    # INITIAL TEST — baseline performance before any adapter training.
    y_true, y_pred = test(xs=xs, attns=attns, scores=scores, loss_fn=loss_fn, model=model)
    y_true_test, y_pred_test = test(xs_test, attns_test, scores_test, loss_fn, model)
    plot_true_preds(y_true_test.cpu(), y_pred_test.cpu(), muts_test)

    print('\n\nRE-TRAINING ESM1v...)')
    # TRAIN
    # https://stackoverflow.com/questions/56360644/pytorch-runtimeerror-expected-tensor-for-argument-1-indices-to-have-scalar-t
    train(xs, attns, scores, loss_fn, model, optimizer, n_epochs=3)

    # TEST — performance after LoRA fine-tuning, on train and held-out sets.
    print('\nPOST-TRAIN-PERFORMANCE')
    y_true, y_pred = test(xs=xs, attns=attns, scores=scores, loss_fn=loss_fn, model=model)
    y_true_test, y_pred_test = test(xs_test, attns_test, scores_test, loss_fn, model)
    plot_true_preds(y_true_test.cpu(), y_pred_test.cpu(), muts_test)
| 357 | + |
0 commit comments