Skip to content

Commit d8a47e7

Browse files
committed
Update func performance_ls_ts
1 parent 8aac858 commit d8a47e7

File tree

3 files changed

+93
-49
lines changed

3 files changed

+93
-49
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@
4848
from pypef.utils.plot import plot_y_true_vs_y_pred
4949
import pypef.dca.gremlin_inference
5050
from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
51-
from pypef.llm.esm_lora_tune import get_batches
51+
from pypef.llm.esm_lora_tune import esm_tokenize_sequences, get_batches, esm_setup
52+
from pypef.llm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences
5253

5354
# sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
5455
# will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function
@@ -396,16 +397,7 @@ def train_llm(self):
396397
#x_llm_ttest_b = get_batches(self.x_llm_ttest, batch_size=self.batch_size, dtype=int)
397398
if self.llm_key == 'prosst':
398399
y_llm_ttest = self.llm_inference_function(
399-
xs=self.x_llm_ttest,
400-
model=self.llm_base_model,
401-
input_ids=self.input_ids,
402-
attention_mask=self.llm_attention_mask,
403-
structure_input_ids=self.structure_input_ids,
404-
train=True,
405-
device=self.device
406-
)
407-
y_llm_ttrain = self.llm_inference_function(
408-
xs=self.x_llm_ttrain,
400+
x_sequences=self.x_llm_ttest,
409401
model=self.llm_base_model,
410402
input_ids=self.input_ids,
411403
attention_mask=self.llm_attention_mask,
@@ -451,17 +443,8 @@ def train_llm(self):
451443
device=self.device,
452444
#seed=self.seed
453445
)
454-
y_llm_lora_ttrain = self.llm_inference_function(
455-
xs=self.x_llm_ttrain,
456-
model=self.llm_model,
457-
input_ids=self.input_ids,
458-
attention_mask=self.llm_attention_mask,
459-
structure_input_ids=self.structure_input_ids,
460-
train=True,
461-
device=self.device
462-
)
463446
y_llm_lora_ttest = self.llm_inference_function(
464-
xs=self.x_llm_ttest,
447+
x_sequences=self.x_llm_ttest,
465448
model=self.llm_model,
466449
input_ids=self.input_ids,
467450
attention_mask=self.llm_attention_mask,
@@ -575,10 +558,10 @@ def hybrid_prediction(
575558
if self.llm_attention_mask is not None:
576559
print('No LLM input for hybrid prediction but the model '
577560
'has been trained using an LLM model input.. '
578-
'Using only DCA for hybrid prediction.. This can lead '
561+
'Using only DCA for hybridprediction.. This can lead '
579562
'to unwanted prediction behavior if the hybrid model '
580563
'is trained including an LLM...')
581-
return self.beta1 * y_dca + self.beta2 * y_ridge
564+
return self.beta1 * y_dca + self.beta2
582565

583566
else:
584567
if self.llm_key == 'prosst':
@@ -601,7 +584,7 @@ def hybrid_prediction(
601584
#desc='Infering LoRA-tuned model',
602585
device=self.device).detach().cpu().numpy()
603586
elif self.llm_key == 'esm1v':
604-
x_llm_b = get_batches(x_llm, batch_size=1, dtype=int)
587+
x_llm_b = get_batches(x_llm, batch_size=self.batch_size, dtype=int)
605588
y_llm = self.llm_inference_function(
606589
x_llm_b,
607590
self.llm_attention_mask,
@@ -615,6 +598,13 @@ def hybrid_prediction(
615598
#desc='Infering LoRA-tuned model',
616599
device=self.device).detach().cpu().numpy()
617600

601+
602+
y_dca, y_ridge, y_llm, y_llm_lora = (
603+
reduce_by_batch_modulo(y_dca, batch_size=self.batch_size),
604+
reduce_by_batch_modulo(y_ridge, batch_size=self.batch_size),
605+
reduce_by_batch_modulo(y_llm, batch_size=self.batch_size),
606+
reduce_by_batch_modulo(y_llm_lora, batch_size=self.batch_size)
607+
)
618608
return self.beta1 * y_dca + self.beta2 * y_ridge + self.beta3 * y_llm + self.beta4 * y_llm_lora
619609

620610
def split_performance(
@@ -1131,12 +1121,16 @@ def generate_model_and_save_pkl(
11311121
save_model_to_dict_pickle(hybrid_model, model_name, beta_1, beta_2, test_spearman_r, reg)
11321122

11331123

1124+
11341125
def performance_ls_ts(
11351126
ls_fasta: str | None,
11361127
ts_fasta: str | None,
11371128
threads: int,
11381129
params_file: str,
11391130
model_pickle_file: str | None = None,
1131+
llm: str | None = None,
1132+
wt_seq: str | None = None,
1133+
pdb_file: str | None = None,
11401134
substitution_sep: str = '/',
11411135
label=False
11421136
):
@@ -1194,30 +1188,35 @@ def performance_ls_ts(
11941188
f"(after removing substitutions at gap positions)."
11951189
)
11961190

1191+
if llm == 'esm':
1192+
llm_dict = esm_setup(train_sequences)
1193+
x_llm_test = esm_tokenize_sequences(
1194+
test_sequences, llm_dict['llm_tokenizer'], max_length=len(test_sequences[0])
1195+
)
1196+
elif llm == 'prosst':
1197+
llm_dict = prosst_setup(wt_seq, pdb_file, sequences=train_sequences)
1198+
x_llm_test = prosst_tokenize_sequences(
1199+
test_sequences, llm_dict['llm_tokenizer'], max_length=len(test_sequences[0])
1200+
)
1201+
else:
1202+
llm_dict = None
1203+
x_llm_test = None
1204+
llm = ''
1205+
1206+
11971207
hybrid_model = DCALLMHybridModel(
1198-
x_train=np.array(x_train),
1208+
x_train_dca=np.array(x_train),
11991209
y_train=np.array(y_train),
1200-
x_test=np.array(x_test),
1201-
y_test=np.array(y_test),
1210+
llm_model_input=llm_dict,
12021211
x_wt=x_wt
12031212
)
1204-
model_name = f'HYBRID{model_type.lower()}'
1213+
model_name = f'HYBRID{model_type.lower()}{llm.lower()}'
12051214

1206-
spearman_r, reg, beta_1, beta_2 = hybrid_model.ls_ts_performance()
1207-
ys_pred = hybrid_model.hybrid_prediction(np.array(x_test), reg, beta_1, beta_2)
1215+
y_test_pred = hybrid_model.hybrid_prediction(np.array(x_test), x_llm_test)
12081216

1209-
if reg is None:
1210-
alpha_ = 'None'
1211-
else:
1212-
alpha_ = f'{reg.alpha:.3f}'
1213-
print(
1214-
f'Individual model weights and regressor hyperparameters:\n'
1215-
f'Hybrid model individual model contributions: Beta1 (DCA): '
1216-
f'{beta_1:.3f}, Beta2 (ML): {beta_2:.3f} (regressor: '
1217-
f'Ridge(alpha={alpha_}))\nTesting performance...'
1218-
)
1217+
print(f'Hybrid performance: {spearmanr(y_test, y_test_pred)}')
12191218

1220-
save_model_to_dict_pickle(hybrid_model, model_name, beta_1, beta_2, spearman_r, reg)
1219+
save_model_to_dict_pickle(hybrid_model, model_name)
12211220

12221221
elif ts_fasta is not None and model_pickle_file is not None and params_file is not None:
12231222
print(f'Taking model from saved model (Pickle file): {model_pickle_file}...')
@@ -1227,18 +1226,18 @@ def performance_ls_ts(
12271226
if model_type != 'Hybrid': # same as below in next elif
12281227
x_test, test_variants, test_sequences, y_test, x_wt, *_ = plmc_or_gremlin_encoding(
12291228
test_variants, test_sequences, y_test, model_pickle_file, substitution_sep, threads, False)
1230-
ys_pred = get_delta_e_statistical_model(x_test, x_wt)
1229+
y_test_pred = get_delta_e_statistical_model(x_test, x_wt)
12311230
else: # Hybrid model input requires params from plmc or GREMLIN model
1232-
beta_1, beta_2, reg = model.beta_1, model.beta_2, model.regressor
1231+
#beta_1, beta_2, reg = model.beta_1, model.beta_2, model.regressor
12331232
x_test, test_variants, test_sequences, y_test, *_ = plmc_or_gremlin_encoding(
12341233
test_variants, test_sequences, y_test, params_file,
12351234
substitution_sep, threads, False
12361235
)
1237-
ys_pred = model.hybrid_prediction(x_test, reg, beta_1, beta_2)
1236+
y_test_pred = model.hybrid_prediction(x_test)
12381237

12391238
elif ts_fasta is not None and model_pickle_file is None: # no LS provided --> statistical modeling / no ML
12401239
print(f'No learning set provided, falling back to statistical DCA model: '
1241-
f'no adjustments of individual hybrid model parameters (beta_1 and beta_2).')
1240+
f'no adjustments of individual hybrid model parameters (beta_1 and beta_2).')
12421241
test_sequences, test_variants, y_test = get_sequences_from_file(ts_fasta)
12431242
x_test, test_variants, test_sequences, y_test, x_wt, model, model_type = plmc_or_gremlin_encoding(
12441243
test_variants, test_sequences, y_test, params_file, substitution_sep, threads
@@ -1248,20 +1247,20 @@ def performance_ls_ts(
12481247
f"Remaining: {len(test_variants)} (after removing "
12491248
f"substitutions at gap positions).")
12501249

1251-
ys_pred = get_delta_e_statistical_model(x_test, x_wt)
1250+
y_test_pred = get_delta_e_statistical_model(x_test, x_wt)
12521251

1253-
save_model_to_dict_pickle(model, model_type, None, None, spearmanr(y_test, ys_pred)[0], None)
1252+
save_model_to_dict_pickle(model, model_type, None, None, spearmanr(y_test, y_test_pred)[0], None)
12541253

12551254
model_type = f'{model_type}_no_ML'
12561255

12571256
else:
12581257
raise SystemError('No Test Set given for performance estimation.')
12591258

1260-
spearman_rho = spearmanr(y_test, ys_pred)
1259+
spearman_rho = spearmanr(y_test, y_test_pred)
12611260
print(f'Spearman Rho = {spearman_rho[0]:.3f}')
12621261

12631262
plot_y_true_vs_y_pred(
1264-
np.array(y_test), np.array(ys_pred), np.array(test_variants),
1263+
np.array(y_test), np.array(y_test_pred), np.array(test_variants),
12651264
label=label, hybrid=True, name=model_type
12661265
)
12671266

pypef/llm/esm_lora_tune.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,23 @@ def esm_train(xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
184184
)
185185
y_preds_b = y_preds_b.detach()
186186
model.train(False)
187+
188+
189+
def esm_setup(sequences, device: str | None = None):
190+
esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
191+
esm_base_model = esm_base_model.to(device)
192+
x_esm, esm_attention_mask = esm_tokenize_sequences(sequences, esm_tokenizer, max_length=len(sequences[0]))
193+
llm_dict_esm = {
194+
'esm1v': {
195+
'llm_base_model': esm_base_model,
196+
'llm_model': esm_lora_model,
197+
'llm_optimizer': esm_optimizer,
198+
'llm_train_function': esm_train,
199+
'llm_inference_function': esm_infer,
200+
'llm_loss_function': corr_loss,
201+
'x_llm_train' : x_esm,
202+
'llm_attention_mask': esm_attention_mask,
203+
'llm_tokenizer': esm_tokenizer
204+
}
205+
}
206+
return llm_dict_esm

pypef/llm/prosst_lora_tune.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,31 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
192192

193193

194194

195+
def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
196+
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
197+
prosst_vocab = prosst_tokenizer.get_vocab()
198+
prosst_base_model = prosst_base_model.to(device)
199+
prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
200+
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(pdb_file, prosst_tokenizer, wt_seq)
201+
x_llm_train_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
202+
llm_dict_prosst = {
203+
'prosst': {
204+
'llm_base_model': prosst_base_model,
205+
'llm_model': prosst_lora_model,
206+
'llm_optimizer': prosst_optimizer,
207+
'llm_train_function': prosst_train,
208+
'llm_inference_function': get_logits_from_full_seqs,
209+
'llm_loss_function': corr_loss,
210+
'x_llm_train' : x_llm_train_prosst,
211+
'llm_attention_mask': prosst_attention_mask,
212+
'input_ids': input_ids,
213+
'structure_input_ids': structure_input_ids,
214+
'llm_tokenizer': prosst_tokenizer
215+
}
216+
}
217+
return llm_dict_prosst
218+
219+
195220
if __name__ == '__main__':
196221
import pandas as pd
197222
import copy

0 commit comments

Comments
 (0)