Skip to content

Commit ddf7d3e

Browse files
committed
LS -> TS Hybrid DCA+LLM works
1 parent 35ca7a8 commit ddf7d3e

File tree

6 files changed

+3865
-69
lines changed

6 files changed

+3865
-69
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,5 @@ scripts/ESM_finetuning/_Description_DMS_substitutions_data.csv
416416
scripts/ESM_finetuning/mut_performance_violin.png
417417
datasets/ANEH/SSM_landscape.png
418418
datasets/ANEH/SSM_landscape.csv
419+
datasets/AVGFP/model_saves/*
420+
datasets/AVGFP/Pickles/*

datasets/AVGFP/GFP_AEQVI.pdb

Lines changed: 3737 additions & 0 deletions
Large diffs are not rendered by default.

pypef/hybrid/hybrid_model.py

Lines changed: 119 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from sklearn.linear_model import Ridge
3838
from sklearn.model_selection import GridSearchCV, train_test_split
3939
from scipy.optimize import differential_evolution
40+
from Bio import SeqIO, BiopythonParserWarning
41+
warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
4042

4143
from pypef.utils.variant_data import (
4244
get_sequences_from_file, get_seqs_from_var_name,
@@ -327,16 +329,28 @@ def _adjust_betas(
327329
return minimizer.x
328330

329331
def get_subsplits_train(self, train_size_fit: float = 0.66):
332+
print("Getting subsplits for supervised (re-)training of models "
333+
"and for adjustment of hybrid component contribution "
334+
"weights (\"beta's\")..."
335+
)
336+
train_size_fit = int(train_size_fit * len(self.y_train))
337+
train_size_beta_adjustment = len(self.y_train) - train_size_fit
338+
print(f"Splitting training data of size {len(self.y_train)} "
339+
f"into {train_size_fit} variants for model tuning and "
340+
f"{train_size_beta_adjustment} variants for hybrid model "
341+
f"beta adjustment...")
330342
if len(self.parameter_range) == 4:
331343
# Reduce sizes by batch modulo
332-
train_size_fit = int(
333-
(train_size_fit * len(self.y_train)) -
334-
((train_size_fit * len(self.y_train)) % self.batch_size)
335-
)
336-
#train_test_size = int(
337-
# (len(self.y_train) - train_size_fit) -
338-
# ((len(self.y_train) - train_size_fit) % self.batch_size)
339-
#)
344+
n_drop = train_size_fit % self.batch_size
345+
if n_drop > 0:
346+
train_size_fit = train_size_fit - n_drop
347+
train_size_beta_adjustment = len(self.y_train) - train_size_fit
348+
print(f"Shifting {n_drop} variants from training set to "
349+
f"beta adjustment set to match batch requirements "
350+
f"of batch size {self.batch_size} for LLM retraining "
351+
f"resulting in {train_size_fit} variants for model "
352+
f"tuning and {train_size_beta_adjustment} variants "
353+
f"for hybrid model beta adjustment...")
340354
(
341355
self.x_dca_ttrain, self.x_dca_ttest,
342356
self.x_llm_ttrain, self.x_llm_ttest,
@@ -348,14 +362,6 @@ def get_subsplits_train(self, train_size_fit: float = 0.66):
348362
train_size=train_size_fit,
349363
random_state=self.seed
350364
)
351-
# Reducing by batch size modulo for X and y
352-
self.x_dca_ttrain = self.x_dca_ttrain[:train_size_fit]
353-
self.x_llm_ttrain = self.x_llm_ttrain[:train_size_fit]
354-
self.y_ttrain = self.y_ttrain[:train_size_fit]
355-
#self.x_dca_ttest = self.x_dca_ttest[:train_test_size]
356-
#self.x_llm_ttest = self.x_llm_ttest[:train_test_size]
357-
#self.y_ttest = self.y_ttest[:train_test_size]
358-
359365
else:
360366
(
361367
self.x_dca_ttrain, self.x_dca_ttest,
@@ -526,12 +532,15 @@ def train_and_optimize(self) -> tuple:
526532
if len(self.parameter_range) == 4:
527533
self.train_llm()
528534
self.beta1, self.beta2, self.beta3, self.beta4 = self._adjust_betas(
529-
self.y_ttest, self.y_dca_ttest, self.y_dca_ridge_ttest, self.y_llm_ttest, self.y_llm_lora_ttest
535+
self.y_ttest, self.y_dca_ttest, self.y_dca_ridge_ttest,
536+
self.y_llm_ttest, self.y_llm_lora_ttest
530537
)
531538
return self.beta1, self.beta2, self.beta3, self.beta4, self.ridge_opt
532539

533540
else:
534-
self.beta1, self.beta2 = self._adjust_betas(self.y_ttest, self.y_dca_ttest, self.y_dca_ridge_ttest)
541+
self.beta1, self.beta2 = self._adjust_betas(self.y_ttest,
542+
self.y_dca_ttest, self.y_dca_ridge_ttest
543+
)
535544
return self.beta1, self.beta2, self.ridge_opt
536545

537546

@@ -607,7 +616,10 @@ def hybrid_prediction(
607616
self.llm_model,
608617
device=self.device).detach().cpu().numpy()
609618

610-
return self.beta1 * y_dca + self.beta2 * y_ridge + self.beta3 * y_llm + self.beta4 * y_llm_lora
619+
return (
620+
self.beta1 * y_dca + self.beta2 * y_ridge +
621+
self.beta3 * y_llm + self.beta4 * y_llm_lora
622+
)
611623

612624
def ls_ts_performance(self):
613625
beta_1, beta_2, reg = self.settings(
@@ -724,15 +736,20 @@ def get_model_path(model: str):
724736
model_path = f'Pickles/{model}'
725737
else:
726738
raise SystemError(
727-
"Did not find specified model file in current working directory "
728-
" or /Pickles subdirectory. Make sure to train/save a model first "
729-
"(e.g., for saving a GREMLIN model, type \"pypef param_inference --msa TARGET_MSA.a2m\" "
730-
"or, for saving a plmc model, type \"pypef param_inference --params TARGET_PLMC.params\")."
739+
"Did not find specified model file in current "
740+
"working directory or /Pickles subdirectory. "
741+
"Make sure to train/save a model first (e.g., "
742+
"for saving a GREMLIN model, type \"pypef "
743+
"param_inference --msa TARGET_MSA.a2m\" or, for"
744+
"saving a plmc model, type \"pypef param_inference"
745+
" --params TARGET_PLMC.params\")."
731746
)
732747
return model_path
733748
except TypeError:
734-
raise SystemError("No provided model. "
735-
"Specify a model for DCA-based encoding.")
749+
raise SystemError(
750+
"No provided model. Specify a " \
751+
"model for DCA-based encoding."
752+
)
736753

737754

738755
def get_model_and_type(
@@ -772,11 +789,7 @@ def get_model_and_type(
772789

773790
def save_model_to_dict_pickle(
774791
model: DCALLMHybridModel | PLMC | GREMLIN,
775-
model_type: str | None = None,
776-
beta_1: float | None = None,
777-
beta_2: float | None = None,
778-
spearman_r: float | None = None,
779-
regressor: sklearn.base.BaseEstimator = None
792+
model_type: str | None = None
780793
):
781794
try:
782795
os.mkdir('Pickles')
@@ -790,11 +803,7 @@ def save_model_to_dict_pickle(
790803
pickle.dump(
791804
{
792805
'model': model,
793-
'model_type': model_type,
794-
'beta_1': beta_1,
795-
'beta_2': beta_2,
796-
'spearman_rho': spearman_r,
797-
'regressor': regressor
806+
'model_type': model_type
798807
},
799808
open(f'Pickles/{model_type}', 'wb')
800809
)
@@ -816,19 +825,21 @@ def plmc_or_gremlin_encoding(
816825
use_global_model=False
817826
):
818827
"""
819-
Decides based on the params file input type which DCA encoding to be performed, i.e.,
820-
GREMLIN or PLMC.
821-
If use_global_model==True, to avoid each time pickle model file getting loaded, which
822-
is quite inefficient when performing directed evolution, i.e., encoding of single
823-
sequences, a global model is stored at the first evolution step and used in the
824-
subsequent steps.
828+
Decides based on the params file input type which DCA encoding
829+
to be performed, i.e., GREMLIN or PLMC.
830+
If use_global_model==True, to avoid each time pickle model
831+
file getting loaded, which is quite inefficient when performing
832+
directed evolution, i.e., encoding of single sequences, a
833+
global model is stored at the first evolution step and used
834+
in the subsequent steps.
825835
"""
826836
global global_model, global_model_type
827837
if ys_true is None:
828838
ys_true = np.zeros(np.shape(sequences))
829839
if use_global_model:
830840
if global_model is None:
831-
global_model, global_model_type = get_model_and_type(params_file, substitution_sep)
841+
global_model, global_model_type = get_model_and_type(
842+
params_file, substitution_sep)
832843
model, model_type = global_model, global_model_type
833844
else:
834845
model, model_type = global_model, global_model_type
@@ -840,12 +851,16 @@ def plmc_or_gremlin_encoding(
840851
)
841852
elif model_type == 'GREMLIN':
842853
if verbose:
843-
print(f"Following positions are frequent gap positions in the MSA "
844-
f"and cannot be considered for effective modeling, i.e., "
845-
f"substitutions at these positions are removed as these would be "
846-
f"predicted with wild-type fitness:\n{[int(gap) + 1 for gap in model.gaps]}.\n"
847-
f"Effective positions (N={len(model.v_idx)}) are:\n"
848-
f"{[int(v_pos) + 1 for v_pos in model.v_idx]}")
854+
print(
855+
f"Following positions are frequent gap positions "
856+
f"in the MSA and cannot be considered for effective "
857+
f"modeling, i.e., substitutions at these positions "
858+
f"are removed as these would be predicted with "
859+
f"wild-type fitness:"
860+
f"\n{[int(gap) + 1 for gap in model.gaps]}.\n"
861+
f"Effective positions (N={len(model.v_idx)}) are:\n"
862+
f"{[int(v_pos) + 1 for v_pos in model.v_idx]}"
863+
)
849864
xs, x_wt, variants, sequences, ys_true = gremlin_encoding(
850865
model, variants, sequences, ys_true,
851866
shift_pos=1, substitution_sep=substitution_sep
@@ -920,11 +935,14 @@ def remove_gap_pos(
920935
Returns
921936
-----------
922937
variants_v
923-
Variants with substitutions at valid sequence positions, i.e., at non-gap positions
938+
Variants with substitutions at valid sequence positions,
939+
i.e., at non-gap positions
924940
sequences_v
925-
Sequences of variants with substitutions at valid sequence positions, i.e., at non-gap positions
941+
Sequences of variants with substitutions at valid sequence positions,
942+
i.e., at non-gap positions
926943
fitnesses_v
927-
Fitness values of variants with substitutions at valid sequence positions, i.e., at non-gap positions
944+
Fitness values of variants with substitutions at valid sequence positions,
945+
i.e., at non-gap positions
928946
"""
929947
variants_v, sequences_v, fitnesses_v = [], [], []
930948
valid = []
@@ -1029,12 +1047,12 @@ def llm_embedder(llm_dict, seqs):
10291047
np.shape(seqs)
10301048
#except np.shape error:
10311049
if list(llm_dict.keys())[0] == 'esm1v':
1032-
x_llm_seqs = esm_tokenize_sequences(
1033-
seqs, llm_dict['esm1v']['llm_tokenizer'], max_length=len(seqs[0])
1050+
x_llm_seqs, _attention_mask = esm_tokenize_sequences(
1051+
seqs, tokenizer=llm_dict['esm1v']['llm_tokenizer'], max_length=len(seqs[0])
10341052
)
10351053
elif list(llm_dict.keys())[0] == 'prosst':
10361054
x_llm_seqs = prosst_tokenize_sequences(
1037-
seqs, llm_dict['prosst']['llm_tokenizer'], max_length=len(seqs[0])
1055+
seqs, vocab=llm_dict['prosst']['llm_vocab']
10381056
)
10391057
else:
10401058
raise SystemError(f"Unknown LLM dictionary input:\n{list(llm_dict.keys())[0]}")
@@ -1048,8 +1066,8 @@ def performance_ls_ts(
10481066
params_file: str,
10491067
model_pickle_file: str | None = None,
10501068
llm: str | None = None,
1051-
wt_seq: str | None = None,
10521069
pdb_file: str | None = None,
1070+
wt_seq: str | None = None,
10531071
substitution_sep: str = '/',
10541072
label=False
10551073
):
@@ -1091,32 +1109,58 @@ def performance_ls_ts(
10911109
test_sequences, test_variants, y_test = get_sequences_from_file(ts_fasta)
10921110

10931111
if ls_fasta is not None and ts_fasta is not None:
1094-
train_sequences, train_variants, y_train = get_sequences_from_file(ls_fasta)
1095-
x_train, train_variants, train_sequences, y_train, x_wt, _, model_type = plmc_or_gremlin_encoding(
1096-
train_variants, train_sequences, y_train, params_file, substitution_sep, threads
1112+
train_sequences, train_variants, y_train = get_sequences_from_file(
1113+
ls_fasta)
1114+
(
1115+
x_train, train_variants, train_sequences,
1116+
y_train, x_wt, _, model_type
1117+
) = plmc_or_gremlin_encoding(
1118+
train_variants, train_sequences, y_train,
1119+
params_file, substitution_sep, threads
10971120
)
10981121

1099-
x_test, test_variants, test_sequences, y_test, *_ = plmc_or_gremlin_encoding(
1100-
test_variants, test_sequences, y_test, params_file, substitution_sep, threads, verbose=False
1122+
(
1123+
x_test, test_variants, test_sequences, y_test, *_
1124+
) = plmc_or_gremlin_encoding(
1125+
test_variants, test_sequences, y_test, params_file,
1126+
substitution_sep, threads, verbose=False
11011127
)
11021128

11031129
print(f"\nInitial training set variants: {len(train_sequences)}. "
11041130
f"Remaining: {len(train_variants)} (after removing "
11051131
f"substitutions at gap positions).\nInitial test set "
1106-
f"variants: {len(test_sequences)}. Remaining: {len(test_variants)} "
1107-
f"(after removing substitutions at gap positions)."
1132+
f"variants: {len(test_sequences)}. Remaining: "
1133+
f"{len(test_variants)} (after removing substitutions "
1134+
f"at gap positions)."
11081135
)
11091136
if llm == 'esm':
11101137
llm_dict = esm_setup(train_sequences)
11111138
x_llm_test = llm_embedder(llm_dict, test_sequences)
11121139
elif llm == 'prosst':
1113-
llm_dict = prosst_setup(wt_seq, pdb_file, sequences=train_sequences)
1140+
if pdb_file is None:
1141+
raise SystemError(
1142+
"Running ProSST requires a PDB file input "
1143+
"for embedding sequences! Specify a PDB file "
1144+
"with the --pdb flag."
1145+
)
1146+
if wt_seq is None:
1147+
raise SystemError(
1148+
"Running ProSST requires a wild-type sequence "
1149+
"FASTA file input for embedding sequences! "
1150+
"Specify a FASTA file with the --wt flag."
1151+
)
1152+
pdb_seq = str(list(SeqIO.parse(pdb_file, "pdb-atom"))[0].seq)
1153+
assert wt_seq == pdb_seq, (
1154+
f"Wild-type sequence is not matching PDB-extracted sequence:"
1155+
f"\nWT sequence:\n{wt_seq}\nPDB sequence:\n{pdb_seq}"
1156+
)
1157+
llm_dict = prosst_setup(
1158+
wt_seq, pdb_file, sequences=train_sequences)
11141159
x_llm_test = llm_embedder(llm_dict, test_sequences)
11151160
else:
11161161
llm_dict = None
11171162
x_llm_test = None
11181163
llm = ''
1119-
11201164
hybrid_model = DCALLMHybridModel(
11211165
x_train_dca=np.array(x_train),
11221166
y_train=np.array(y_train),
@@ -1132,11 +1176,19 @@ def performance_ls_ts(
11321176
print(f'Taking model from saved model (Pickle file): {model_pickle_file}...')
11331177
model, model_type = get_model_and_type(model_pickle_file)
11341178
if model_type != 'Hybrid': # same as below in next elif
1135-
x_test, test_variants, test_sequences, y_test, x_wt, *_ = plmc_or_gremlin_encoding(
1136-
test_variants, test_sequences, y_test, model_pickle_file, substitution_sep, threads, False)
1179+
(
1180+
x_test, test_variants, test_sequences,
1181+
y_test, x_wt, *_
1182+
) = plmc_or_gremlin_encoding(
1183+
test_variants, test_sequences, y_test, model_pickle_file,
1184+
substitution_sep, threads, False
1185+
)
11371186
y_test_pred = get_delta_e_statistical_model(x_test, x_wt)
11381187
else: # Hybrid model input requires params from plmc or GREMLIN model
1139-
x_test, test_variants, test_sequences, y_test, *_ = plmc_or_gremlin_encoding(
1188+
(
1189+
x_test, test_variants, test_sequences,
1190+
y_test, *_
1191+
) = plmc_or_gremlin_encoding(
11401192
test_variants, test_sequences, y_test, params_file,
11411193
substitution_sep, threads, False
11421194
)

pypef/hybrid/hybrid_run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def run_pypef_hybrid_modeling(arguments):
5454
params_file=arguments['--params'],
5555
model_pickle_file=arguments['--model'],
5656
llm=arguments['--llm'],
57+
pdb_file=arguments['--pdb'],
58+
wt_seq=get_wt_sequence(arguments['--wt']),
5759
substitution_sep=arguments['--mutation_sep'],
5860
label=arguments['--label']
5961
)

pypef/llm/prosst_lora_tune.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
202202
'llm_inference_function': get_logits_from_full_seqs,
203203
'llm_loss_function': corr_loss,
204204
'x_llm_train' : x_llm_train_prosst,
205-
'llm_attention_mask': prosst_attention_mask,
205+
'llm_attention_mask': prosst_attention_mask,
206+
'llm_vocab': prosst_vocab,
206207
'input_ids': input_ids,
207208
'structure_input_ids': structure_input_ids,
208209
'llm_tokenizer': prosst_tokenizer

pypef/main.py

100644100755
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@
144144
[--ts TEST_SET] [--ps PREDICTION_SET]
145145
[--model MODEL] [--params PARAM_FILE]
146146
[--ls LEARNING_SET] [--label]
147-
[--llm LLM]
147+
[--llm LLM] [--pdb PDB_FILE] [--wt WT_FASTA]
148148
[--threads THREADS]
149149
pypef hybrid --model MODEL --params PARAM_FILE
150150
[--ts TEST_SET] [--label]
@@ -227,6 +227,7 @@
227227
--opt_iter N_ITER Number of iterations for GREMLIN-based optimization of local fields
228228
and couplings [default: 100].
229229
--params PARAM_FILE Input PLMC couplings parameter file.
230+
--pdb PDB_FILE Input protein structure file in PDB format used for ProSST LLM modeling.
230231
-u --pmult Predict for all prediction files in folder for recombinants
231232
or for diverse variants [default: False].
232233
-p --ps PREDICTION_SET Prediction set for performing predictions using a trained Model.
@@ -346,6 +347,7 @@
346347
Optional('--offset'): Use(int),
347348
Optional('--opt_iter'): Use(int),
348349
Optional('--params'): Or(None, str),
350+
Optional('--pdb'): Or(None, str),
349351
Optional('--pmult'): bool,
350352
Optional('--ps'): Or(None, str),
351353
Optional('--qdiverse'): bool,

0 commit comments

Comments
 (0)