Skip to content

Commit 884462a

Browse files
committed
TS DCA and LLM predictions work
1 parent ddf7d3e commit 884462a

File tree

4 files changed

+89
-51
lines changed

4 files changed

+89
-51
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,5 @@ datasets/ANEH/SSM_landscape.png
418418
datasets/ANEH/SSM_landscape.csv
419419
datasets/AVGFP/model_saves/*
420420
datasets/AVGFP/Pickles/*
421+
datasets/AVGFP/DCA_Hybrid_Model_Performance_ESM1v_no_ML.png
422+
datasets/AVGFP/DCA_Hybrid_Model_Performance_ProSST_no_ML.png

pypef/hybrid/hybrid_model.py

Lines changed: 64 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
from sklearn.linear_model import Ridge
3838
from sklearn.model_selection import GridSearchCV, train_test_split
3939
from scipy.optimize import differential_evolution
40-
from Bio import SeqIO, BiopythonParserWarning
41-
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
4240

4341
from pypef.utils.variant_data import (
4442
get_sequences_from_file, get_seqs_from_var_name,
@@ -97,7 +95,7 @@ def __init__(
9795
self.llm_train_function = llm_model_input['esm1v']['llm_train_function']
9896
self.llm_inference_function = llm_model_input['esm1v']['llm_inference_function']
9997
self.llm_loss_function = llm_model_input['esm1v']['llm_loss_function']
100-
self.x_train_llm = llm_model_input['esm1v']['x_llm_train']
98+
self.x_train_llm = llm_model_input['esm1v']['x_llm']
10199
self.llm_attention_mask = llm_model_input['esm1v']['llm_attention_mask']
102100
elif len(list(llm_model_input.keys())) == 1 and list(llm_model_input.keys())[0] == 'prosst':
103101
self.llm_key = 'prosst'
@@ -107,7 +105,7 @@ def __init__(
107105
self.llm_train_function = llm_model_input['prosst']['llm_train_function']
108106
self.llm_inference_function = llm_model_input['prosst']['llm_inference_function']
109107
self.llm_loss_function = llm_model_input['prosst']['llm_loss_function']
110-
self.x_train_llm = llm_model_input['prosst']['x_llm_train']
108+
self.x_train_llm = llm_model_input['prosst']['x_llm']
111109
self.llm_attention_mask = llm_model_input['prosst']['llm_attention_mask']
112110
self.input_ids = llm_model_input['prosst']['input_ids']
113111
self.structure_input_ids = llm_model_input['prosst']['structure_input_ids']
@@ -844,7 +842,8 @@ def plmc_or_gremlin_encoding(
844842
else:
845843
model, model_type = global_model, global_model_type
846844
else:
847-
model, model_type = get_model_and_type(params_file, substitution_sep)
845+
model, model_type = get_model_and_type(
846+
params_file, substitution_sep)
848847
if model_type == 'PLMC':
849848
xs, x_wt, variants, sequences, ys_true = plmc_encoding(
850849
model, variants, sequences, ys_true, threads, verbose
@@ -867,20 +866,25 @@ def plmc_or_gremlin_encoding(
867866
)
868867
else:
869868
raise SystemError(
870-
f"Found a {model_type.lower()} model as input. Please train a new "
871-
f"hybrid model on the provided LS/TS datasets."
869+
f"Found a {model_type.lower()} model as input. Please "
870+
f"train a new hybrid model on the provided LS/TS datasets."
872871
)
873872
assert len(xs) == len(variants) == len(sequences) == len(ys_true)
874873
return xs, variants, sequences, ys_true, x_wt, model, model_type
875874

876875

877-
def gremlin_encoding(gremlin: GREMLIN, variants, sequences, ys_true, shift_pos=1, substitution_sep='/'):
876+
def gremlin_encoding(gremlin: GREMLIN, variants, sequences, ys_true,
877+
shift_pos=1, substitution_sep='/'):
878878
"""
879879
Gets X and x_wt for DCA prediction: delta_Hamiltonian respectively
880880
delta_E = np.subtract(X, x_wt), with X = encoded sequences of variants.
881881
Also removes variants, sequences, and y_trues at MSA gap positions.
882882
"""
883-
variants, sequences, ys_true = np.atleast_1d(variants), np.atleast_1d(sequences), np.atleast_1d(ys_true)
883+
variants, sequences, ys_true = (
884+
np.atleast_1d(variants),
885+
np.atleast_1d(sequences),
886+
np.atleast_1d(ys_true)
887+
)
884888
variants, sequences, ys_true = remove_gap_pos(
885889
gremlin.gaps, variants, sequences, ys_true,
886890
shift_pos=shift_pos, substitution_sep=substitution_sep
@@ -993,7 +997,8 @@ def generate_model_and_save_pkl(
993997
"""
994998
wt_seq = get_wt_sequence(wt)
995999
variants_splitted = split_variants(variants, substitution_sep)
996-
variants, ys_true, sequences = get_seqs_from_var_name(wt_seq, variants_splitted, ys_true)
1000+
variants, ys_true, sequences = get_seqs_from_var_name(
1001+
wt_seq, variants_splitted, ys_true)
9971002

9981003
xs, variants, sequences, ys_true, x_wt, _model, model_type = plmc_or_gremlin_encoding(
9991004
variants, sequences, ys_true, params_file, substitution_sep, threads)
@@ -1043,9 +1048,10 @@ def generate_model_and_save_pkl(
10431048

10441049

10451050
def llm_embedder(llm_dict, seqs):
1046-
#try:
1047-
np.shape(seqs)
1048-
#except np.shape error:
1051+
try:
1052+
np.shape(seqs)
1053+
except ValueError:
1054+
raise SystemError("Unequal input sequence length detected!")
10491055
if list(llm_dict.keys())[0] == 'esm1v':
10501056
x_llm_seqs, _attention_mask = esm_tokenize_sequences(
10511057
seqs, tokenizer=llm_dict['esm1v']['llm_tokenizer'], max_length=len(seqs[0])
@@ -1069,7 +1075,8 @@ def performance_ls_ts(
10691075
pdb_file: str | None = None,
10701076
wt_seq: str | None = None,
10711077
substitution_sep: str = '/',
1072-
label=False
1078+
label=False,
1079+
device: str | None = None
10731080
):
10741081
"""
10751082
Description
@@ -1137,23 +1144,6 @@ def performance_ls_ts(
11371144
llm_dict = esm_setup(train_sequences)
11381145
x_llm_test = llm_embedder(llm_dict, test_sequences)
11391146
elif llm == 'prosst':
1140-
if pdb_file is None:
1141-
raise SystemError(
1142-
"Running ProSST requires a PDB file input "
1143-
"for embedding sequences! Specify a PDB file "
1144-
"with the --pdb flag."
1145-
)
1146-
if wt_seq is None:
1147-
raise SystemError(
1148-
"Running ProSST requires a wild-type sequence "
1149-
"FASTA file input for embedding sequences! "
1150-
"Specify a FASTA file with the --wt flag."
1151-
)
1152-
pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq)
1153-
assert wt_seq == pdb_seq, (
1154-
f"Wild-type sequence is not matching PDB-extracted sequence:"
1155-
f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}"
1156-
)
11571147
llm_dict = prosst_setup(
11581148
wt_seq, pdb_file, sequences=train_sequences)
11591149
x_llm_test = llm_embedder(llm_dict, test_sequences)
@@ -1173,6 +1163,7 @@ def performance_ls_ts(
11731163
save_model_to_dict_pickle(hybrid_model, model_name)
11741164

11751165
elif ts_fasta is not None and model_pickle_file is not None and params_file is not None:
1166+
# # no LS provided --> statistical modeling / no ML
11761167
print(f'Taking model from saved model (Pickle file): {model_pickle_file}...')
11771168
model, model_type = get_model_and_type(model_pickle_file)
11781169
if model_type != 'Hybrid': # same as below in next elif
@@ -1193,33 +1184,60 @@ def performance_ls_ts(
11931184
substitution_sep, threads, False
11941185
)
11951186
if model.llm_model_input is not None:
1196-
if list(model.llm_model_input.keys())[0] == 'esm1v':
1197-
pass
1187+
print(f"Found hybrid model with LLM {list(model.llm_model_input.keys())[0]}...")
1188+
x_llm_test = llm_embedder(llm_dict, test_sequences)
1189+
model.hybrid_prediction(x_test, x_llm_test)
11981190
else:
11991191
y_test_pred = model.hybrid_prediction(x_test)
12001192

12011193
elif ts_fasta is not None and model_pickle_file is None: # no LS provided --> statistical modeling / no ML
1202-
print(f'No learning set provided, falling back to statistical DCA model: '
1203-
f'no adjustments of individual hybrid model parameters (beta_1 and beta_2).')
1194+
print(f"No learning set provided, falling back to statistical DCA model: "
1195+
f"no adjustments of individual hybrid model parameters (\"beta's\").")
12041196
test_sequences, test_variants, y_test = get_sequences_from_file(ts_fasta)
1205-
x_test, test_variants, test_sequences, y_test, x_wt, model, model_type = plmc_or_gremlin_encoding(
1206-
test_variants, test_sequences, y_test, params_file, substitution_sep, threads
1197+
(
1198+
x_test, test_variants, test_sequences,
1199+
y_test, x_wt, model, model_type
1200+
) = plmc_or_gremlin_encoding(
1201+
test_variants, test_sequences, y_test,
1202+
params_file, substitution_sep, threads
12071203
)
1208-
12091204
print(f"Initial test set variants: {len(test_sequences)}. "
12101205
f"Remaining: {len(test_variants)} (after removing "
12111206
f"substitutions at gap positions).")
1212-
12131207
y_test_pred = get_delta_e_statistical_model(x_test, x_wt)
1214-
save_model_to_dict_pickle(model, model_type, None, None, spearmanr(y_test, y_test_pred)[0], None)
1208+
if llm == 'esm':
1209+
llm_dict = esm_setup(test_sequences)
1210+
x_llm_test = llm_embedder(llm_dict, test_sequences)
1211+
y_test_pred_llm = llm_dict['esm1v']['llm_inference_function'](
1212+
xs=get_batches(x_llm_test, batch_size=1, dtype=int),
1213+
attention_mask=llm_dict['esm1v']['llm_attention_mask'],
1214+
model=llm_dict['esm1v']['llm_base_model'],
1215+
device=device
1216+
).cpu()
1217+
plot_y_true_vs_y_pred(
1218+
np.array(y_test), np.array(y_test_pred_llm), np.array(test_variants),
1219+
label=label, hybrid=True, name=f'ESM1v_no_ML'
1220+
)
1221+
elif llm == 'prosst':
1222+
llm_dict = prosst_setup(
1223+
wt_seq, pdb_file, sequences=test_sequences)
1224+
x_llm_test = llm_embedder(llm_dict, test_sequences)
1225+
y_test_pred_llm = llm_dict['prosst']['llm_inference_function'](
1226+
xs=x_llm_test,
1227+
model=llm_dict['prosst']['llm_base_model'],
1228+
input_ids=llm_dict['prosst']['input_ids'],
1229+
attention_mask=llm_dict['prosst']['llm_attention_mask'],
1230+
structure_input_ids=llm_dict['prosst']['structure_input_ids'],
1231+
device=device
1232+
).cpu()
1233+
plot_y_true_vs_y_pred(
1234+
np.array(y_test), np.array(y_test_pred_llm), np.array(test_variants),
1235+
label=label, hybrid=True, name=f'ProSST_no_ML'
1236+
)
1237+
save_model_to_dict_pickle(model, model_type)
12151238
model_type = f'{model_type}_no_ML'
1216-
12171239
else:
1218-
raise SystemError('No Test Set given for performance estimation.')
1219-
1220-
spearman_rho = spearmanr(y_test, y_test_pred)
1221-
print(f'Spearman Rho = {spearman_rho[0]:.3f}')
1222-
1240+
raise SystemError('No test set given for performance estimation.')
12231241
plot_y_true_vs_y_pred(
12241242
np.array(y_test), np.array(y_test_pred), np.array(test_variants),
12251243
label=label, hybrid=True, name=model_type

pypef/llm/esm_lora_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def esm_setup(sequences, device: str | None = None):
221221
'llm_train_function': esm_train,
222222
'llm_inference_function': esm_infer,
223223
'llm_loss_function': corr_loss,
224-
'x_llm_train' : x_esm,
224+
'x_llm' : x_esm,
225225
'llm_attention_mask': esm_attention_mask,
226226
'llm_tokenizer': esm_tokenizer
227227
}

pypef/llm/prosst_lora_tune.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010
from sys import path
1111
import os
1212
path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
13+
import warnings
1314

1415
import torch
1516
import numpy as np
1617
from scipy.stats import spearmanr
1718
from tqdm import tqdm
18-
1919
from transformers import AutoModelForMaskedLM, AutoTokenizer
20-
21-
2220
from peft import LoraConfig, get_peft_model
21+
from Bio import SeqIO, BiopythonParserWarning
22+
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
2323

2424
from pypef.llm.esm_lora_tune import corr_loss, get_batches
2525
from pypef.llm.prosst_structure.quantizer import PdbQuantizer
@@ -187,6 +187,24 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
187187

188188

189189
def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
190+
if wt_seq is None:
191+
raise SystemError(
192+
"Running ProSST requires a wild-type sequence "
193+
"FASTA file input for embedding sequences! "
194+
"Specify a FASTA file with the --wt flag."
195+
)
196+
if pdb_file is None:
197+
raise SystemError(
198+
"Running ProSST requires a PDB file input "
199+
"for embedding sequences! Specify a PDB file "
200+
"with the --pdb flag."
201+
)
202+
203+
pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq)
204+
assert wt_seq == pdb_seq, (
205+
f"Wild-type sequence is not matching PDB-extracted sequence:"
206+
f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}"
207+
)
190208
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
191209
prosst_vocab = prosst_tokenizer.get_vocab()
192210
prosst_base_model = prosst_base_model.to(device)
@@ -201,7 +219,7 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
201219
'llm_train_function': prosst_train,
202220
'llm_inference_function': get_logits_from_full_seqs,
203221
'llm_loss_function': corr_loss,
204-
'x_llm_train' : x_llm_train_prosst,
222+
'x_llm' : x_llm_train_prosst,
205223
'llm_attention_mask': prosst_attention_mask,
206224
'llm_vocab': prosst_vocab,
207225
'input_ids': input_ids,

0 commit comments

Comments (0)