Skip to content

Commit 81c485f

Browse files
committed
Update hybrid_model.py and appearance esm_lora_tune.py
1 parent 1ceb458 commit 81c485f

File tree

2 files changed

+96
-163
lines changed

2 files changed

+96
-163
lines changed

pypef/hybrid/hybrid_model.py

Lines changed: 53 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -609,69 +609,6 @@ def hybrid_prediction(
609609

610610
return self.beta1 * y_dca + self.beta2 * y_ridge + self.beta3 * y_llm + self.beta4 * y_llm_lora
611611

612-
def split_performance(
613-
self,
614-
train_size: float = 0.8,
615-
n_runs: int = 10,
616-
seed: int = 42,
617-
save_model: bool = False
618-
) -> dict:
619-
"""
620-
TODO: Update
621-
Estimates performance of the model.
622-
623-
Parameters
624-
----------
625-
train_size : int or float (default=0.8)
626-
Number of samples in the training dataset
627-
or fraction of full dataset used for training.
628-
n_runs : int (default=10)
629-
Number of different splits to perform.
630-
seed : int (default=42)
631-
Seed for random generator.
632-
save_model : bool (default=False)
633-
If True, model is saved using pickle, else not.
634-
635-
Returns
636-
-------
637-
data : dict
638-
Contains information about hybrid model parameters
639-
and performance results.
640-
"""
641-
data = {}
642-
np.random.seed(seed)
643-
644-
for r, random_state in enumerate(np.random.randint(100, size=n_runs)):
645-
x_train, x_test, y_train, y_test = train_test_split(
646-
self.X, self.y, train_size=train_size, random_state=random_state)
647-
beta_1, beta_2, reg = self.settings(x_train, y_train)
648-
if beta_2 == 0.0:
649-
alpha = np.nan
650-
else:
651-
if save_model:
652-
pickle.dumps(reg)
653-
alpha = reg.alpha
654-
data.update(
655-
{f'{len(y_train)}_{r}':
656-
{
657-
'no_run': r,
658-
'n_y_train': len(y_train),
659-
'n_y_test': len(y_test),
660-
'rnd_state': random_state,
661-
'spearman_rho': self.spearmanr(
662-
y_test, self.hybrid_prediction(
663-
x_test, reg, beta_1, beta_2
664-
)
665-
),
666-
'beta_1': beta_1,
667-
'beta_2': beta_2,
668-
'alpha': alpha
669-
}
670-
}
671-
)
672-
673-
return data
674-
675612
def ls_ts_performance(self):
676613
beta_1, beta_2, reg = self.settings(
677614
x_train=self.x_train,
@@ -744,49 +681,13 @@ def train_and_test(
744681
test_spearman_r = None
745682
return beta_1, beta_2, reg, self._spearmanr_dca, test_spearman_r
746683

747-
def get_train_sizes(self) -> np.ndarray:
748-
"""
749-
Generates a list of train sizes to perform low-n with.
750-
751-
Returns
752-
-------
753-
Numpy array of train sizes up to 80% (i.e. 0.8 * N_variants).
754-
"""
755-
eighty_percent = int(len(self.y) * 0.8)
756-
757-
train_sizes = np.sort(np.concatenate([
758-
np.arange(15, 50, 5), np.arange(50, 100, 10),
759-
np.arange(100, 150, 20), [160, 200, 250, 300, eighty_percent],
760-
np.arange(400, 1100, 100)
761-
]))
762-
763-
idx_max = np.where(train_sizes >= eighty_percent)[0][0] + 1
764-
return train_sizes[:idx_max]
765-
766-
def run(
767-
self,
768-
train_sizes: list = None,
769-
n_runs: int = 10
770-
) -> dict:
771-
"""
772-
773-
Returns
774-
----------
775-
data: dict
776-
Performances of the split with size of the
777-
training set = train_size and size of the
778-
test set = N_variants - train_size.
779-
"""
780-
data = {}
781-
for t, train_size in enumerate(train_sizes):
782-
print(f'{t + 1}/{len(train_sizes)}:{train_size}')
783-
data.update(self.split_performance(train_size=train_size, n_runs=n_runs))
784-
return data
785684

786685

787-
"""
788-
Below: Some helper functions that call or are dependent on the DCALLMHybridModel class.
789-
"""
686+
"""
687+
###########################################################################################
688+
# Below: Some helper functions that call or are dependent on the DCALLMHybridModel class. #
689+
###########################################################################################
690+
"""
790691

791692

792693
def check_model_type(model: dict | DCALLMHybridModel | PLMC | GREMLIN):
@@ -940,11 +841,11 @@ def plmc_or_gremlin_encoding(
940841
elif model_type == 'GREMLIN':
941842
if verbose:
942843
print(f"Following positions are frequent gap positions in the MSA "
943-
f"and cannot be considered for effective modeling, i.e., "
944-
f"substitutions at these positions are removed as these would be "
945-
f"predicted with wild-type fitness:\n{[int(gap) + 1 for gap in model.gaps]}.\n"
946-
f"Effective positions (N={len(model.v_idx)}) are:\n"
947-
f"{[int(v_pos) + 1 for v_pos in model.v_idx]}")
844+
f"and cannot be considered for effective modeling, i.e., "
845+
f"substitutions at these positions are removed as these would be "
846+
f"predicted with wild-type fitness:\n{[int(gap) + 1 for gap in model.gaps]}.\n"
847+
f"Effective positions (N={len(model.v_idx)}) are:\n"
848+
f"{[int(v_pos) + 1 for v_pos in model.v_idx]}")
948849
xs, x_wt, variants, sequences, ys_true = gremlin_encoding(
949850
model, variants, sequences, ys_true,
950851
shift_pos=1, substitution_sep=substitution_sep
@@ -987,7 +888,7 @@ def plmc_encoding(plmc: PLMC, variants, sequences, ys_true, threads=1, verbose=F
987888
wt_name = target_seq[0] + str(index[0]) + target_seq[0]
988889
if verbose:
989890
print(f"Using to-self-substitution '{wt_name}' as wild type reference. "
990-
f"Encoding variant sequences. This might take some time...")
891+
f"Encoding variant sequences. This might take some time...")
991892
x_wt = get_encoded_sequence(wt_name, plmc)
992893
if threads > 1:
993894
# Hyperthreading, NaNs are already being removed by the called function
@@ -1123,6 +1024,22 @@ def generate_model_and_save_pkl(
11231024
save_model_to_dict_pickle(hybrid_model, model_name, beta_1, beta_2, test_spearman_r, reg)
11241025

11251026

1027+
def llm_embedder(llm_dict, seqs):
1028+
#try:
1029+
np.shape(seqs)
1030+
#except np.shape error:
1031+
if list(llm_dict.keys())[0] == 'esm1v':
1032+
x_llm_seqs = esm_tokenize_sequences(
1033+
seqs, llm_dict['esm1v']['llm_tokenizer'], max_length=len(seqs[0])
1034+
)
1035+
elif list(llm_dict.keys())[0] == 'prosst':
1036+
x_llm_seqs = prosst_tokenize_sequences(
1037+
seqs, llm_dict['prosst']['llm_tokenizer'], max_length=len(seqs[0])
1038+
)
1039+
else:
1040+
raise SystemError(f"Unknown LLM dictionary input:\n{list(llm_dict.keys())[0]}")
1041+
return x_llm_seqs
1042+
11261043

11271044
def performance_ls_ts(
11281045
ls_fasta: str | None,
@@ -1188,25 +1105,18 @@ def performance_ls_ts(
11881105
f"substitutions at gap positions).\nInitial test set "
11891106
f"variants: {len(test_sequences)}. Remaining: {len(test_variants)} "
11901107
f"(after removing substitutions at gap positions)."
1191-
)
1192-
print('LLM:', llm)
1108+
)
11931109
if llm == 'esm':
11941110
llm_dict = esm_setup(train_sequences)
1195-
print('XX', llm_dict)
1196-
x_llm_test = esm_tokenize_sequences(
1197-
test_sequences, llm_dict['esm1v']['llm_tokenizer'], max_length=len(test_sequences[0])
1198-
)
1111+
x_llm_test = llm_embedder(llm_dict, test_sequences)
11991112
elif llm == 'prosst':
12001113
llm_dict = prosst_setup(wt_seq, pdb_file, sequences=train_sequences)
1201-
x_llm_test = prosst_tokenize_sequences(
1202-
test_sequences, llm_dict['prosst']['llm_tokenizer'], max_length=len(test_sequences[0])
1203-
)
1114+
x_llm_test = llm_embedder(llm_dict, test_sequences)
12041115
else:
12051116
llm_dict = None
12061117
x_llm_test = None
12071118
llm = ''
12081119

1209-
12101120
hybrid_model = DCALLMHybridModel(
12111121
x_train_dca=np.array(x_train),
12121122
y_train=np.array(y_train),
@@ -1245,13 +1155,11 @@ def performance_ls_ts(
12451155
)
12461156

12471157
print(f"Initial test set variants: {len(test_sequences)}. "
1248-
f"Remaining: {len(test_variants)} (after removing "
1249-
f"substitutions at gap positions).")
1158+
f"Remaining: {len(test_variants)} (after removing "
1159+
f"substitutions at gap positions).")
12501160

12511161
y_test_pred = get_delta_e_statistical_model(x_test, x_wt)
1252-
12531162
save_model_to_dict_pickle(model, model_type, None, None, spearmanr(y_test, y_test_pred)[0], None)
1254-
12551163
model_type = f'{model_type}_no_ML'
12561164

12571165
else:
@@ -1332,18 +1240,9 @@ def predict_ps( # also predicting "pmult" dict directories
13321240
model, model_type = get_model_and_type(model_pickle_file)
13331241

13341242
if model_type == 'PLMC' or model_type == 'GREMLIN':
1335-
print(f'No hybrid model provided - falling back to a statistical DCA model.')
1243+
print(f'Found {model_type} model file. No hybrid model provided - falling back to a statistical DCA model...')
13361244
elif model_type == 'Hybrid':
1337-
beta_1, beta_2, reg = model.beta_1, model.beta_2, model.regressor
1338-
if reg is None:
1339-
alpha_ = 'None'
1340-
else:
1341-
alpha_ = f'{reg.alpha:.3f}'
1342-
print(
1343-
f'Individual model weights and regressor hyperparameters:\n'
1344-
f'Hybrid model individual model contributions: Beta1 (DCA): {beta_1:.3f}, '
1345-
f'Beta2 (ML): {beta_2:.3f} (regressor: Ridge(alpha={alpha_})).'
1346-
)
1245+
print(f'Found hybrid model...')
13471246

13481247
pmult = [
13491248
'Recomb_Double_Split', 'Recomb_Triple_Split', 'Recomb_Quadruple_Split',
@@ -1365,13 +1264,16 @@ def predict_ps( # also predicting "pmult" dict directories
13651264
variants, sequences, None, model, threads=threads, verbose=False,
13661265
substitution_sep=separator)
13671266
ys_pred = get_delta_e_statistical_model(x_test, x_wt)
1368-
else: # Hybrid model input requires params from plmc or GREMLIN model
1369-
##encoding_model, encoding_model_type = get_model_and_type(params_file)
1267+
else: # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
13701268
x_test, _test_variants, *_ = plmc_or_gremlin_encoding(
13711269
variants, sequences, None, params_file,
13721270
threads=threads, verbose=False, substitution_sep=separator
13731271
)
1374-
ys_pred = model.hybrid_prediction(x_test, reg, beta_1, beta_2)
1272+
if model.llm_model_input is None:
1273+
ys_pred = model.hybrid_prediction(x_test)
1274+
else:
1275+
x_llm_test = llm_embedder(model.llm_model_input, sequences)
1276+
ys_pred = model.hybrid_prediction(np.asarray(x_test), np.asarray(x_llm_test))
13751277
for k, y in enumerate(ys_pred):
13761278
all_y_v_pred.append((ys_pred[k], variants[k]))
13771279
if negative: # sort by fitness value
@@ -1395,13 +1297,17 @@ def predict_ps( # also predicting "pmult" dict directories
13951297
variants, sequences, None, params_file,
13961298
threads=threads, verbose=False, substitution_sep=separator)
13971299
ys_pred = get_delta_e_statistical_model(xs, x_wt)
1398-
else: # Hybrid model input requires params from plmc or GREMLIN model
1300+
else: # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
13991301
xs, variants, *_ = plmc_or_gremlin_encoding(
14001302
variants, sequences, None, params_file,
14011303
threads=threads, verbose=True, substitution_sep=separator
14021304
)
1403-
ys_pred = model.hybrid_prediction(xs, reg, beta_1, beta_2)
1404-
assert len(xs) == len(variants)
1305+
if model.llm_model_input is None:
1306+
ys_pred = model.hybrid_prediction(xs)
1307+
else:
1308+
xs_llm = llm_embedder(model.llm_model_input, sequences)
1309+
ys_pred = model.hybrid_prediction(np.asarray(xs), np.asarray(xs_llm))
1310+
assert len(xs) == len(variants) == len(xs_llm) == len(ys_pred)
14051311
y_v_pred = zip(ys_pred, variants)
14061312
y_v_pred = sorted(y_v_pred, key=lambda x: x[0], reverse=True)
14071313
predictions_out(
@@ -1436,14 +1342,18 @@ def predict_directed_evolution(
14361342
if not list(xs):
14371343
return 'skip'
14381344
y_pred = get_delta_e_statistical_model(xs, x_wt)
1439-
else: # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model
1345+
else: # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model plus optional LLM input
14401346
xs, variant, *_ = plmc_or_gremlin_encoding(
14411347
variant, sequence, None, encoder, verbose=False, use_global_model=True
14421348
)
14431349
if not list(xs):
14441350
return 'skip'
1351+
if model.llm_model_input is None:
1352+
x_llm = None
1353+
else:
1354+
x_llm = llm_embedder(model.llm_model_input, sequence)
14451355
try:
1446-
y_pred = model.hybrid_prediction(np.atleast_2d(xs), model.regressor, model.beta_1, model.beta_2)[0]
1356+
y_pred = model.hybrid_prediction(np.atleast_2d(xs), np.atleast_2d(x_llm))[0]
14471357
except ValueError:
14481358
raise SystemError(
14491359
"Probably a different model was used for encoding than for modeling; "

0 commit comments

Comments
 (0)