Skip to content

Commit 2940f06

Browse files
committed
Working on directed evo hybrid DCA+LLM
1 parent b111555 commit 2940f06

File tree

5 files changed

+156
-41
lines changed

5 files changed

+156
-41
lines changed

.vscode/launch.json

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@
6666
]
6767
},
6868

69+
{
70+
"name": "Python: PyPEF MKPS avGFP PS",
71+
"type": "debugpy",
72+
"request": "launch",
73+
"env": {"PYTHONPATH": "${workspaceFolder}"},
74+
"program": "${workspaceFolder}/pypef/main.py",
75+
"console": "integratedTerminal",
76+
"justMyCode": true,
77+
"cwd": "${workspaceFolder}/datasets/AVGFP/",
78+
"args": [
79+
"mkps",
80+
"--wt", "P42212_F64L.fasta",
81+
"--input", "avGFP.csv"
82+
]
83+
},
84+
6985
{
7086
"name": "Python: PyPEF ml -e onehot pls_loocv",
7187
"type": "debugpy",
@@ -282,6 +298,58 @@
282298
]
283299
},
284300

301+
{
302+
"name": "Python: PyPEF hybrid/only-PS-zero-shot GREMLIN-DCA avGFP PS: ProSST",
303+
"type": "debugpy",
304+
"request": "launch",
305+
"env": {"PYTHONPATH": "${workspaceFolder}"},
306+
"program": "${workspaceFolder}/pypef/main.py",
307+
"console": "integratedTerminal",
308+
"justMyCode": true,
309+
"cwd": "${workspaceFolder}/datasets/AVGFP/",
310+
"args": [
311+
"hybrid",
312+
"-m", "HYBRIDgremlinprosst",
313+
"--ps", "avGFP_prediction_set.fasta",
314+
"--params", "GREMLIN"
315+
]
316+
},
317+
318+
{
319+
"name": "Python: PyPEF hybrid/only-PS-zero-shot GREMLIN-DCA avGFP PS: ESM1v",
320+
"type": "debugpy",
321+
"request": "launch",
322+
"env": {"PYTHONPATH": "${workspaceFolder}"},
323+
"program": "${workspaceFolder}/pypef/main.py",
324+
"console": "integratedTerminal",
325+
"justMyCode": true,
326+
"cwd": "${workspaceFolder}/datasets/AVGFP/",
327+
"args": [
328+
"hybrid",
329+
"-m", "HYBRIDgremlinesm",
330+
"--ps", "avGFP_prediction_set.fasta",
331+
"--params", "GREMLIN"
332+
]
333+
},
334+
335+
{
336+
"name": "Python: PyPEF hybrid/only-PS-zero-shot GREMLIN-DCA avGFP DirectEvo: ESM1v",
337+
"type": "debugpy",
338+
"request": "launch",
339+
"env": {"PYTHONPATH": "${workspaceFolder}"},
340+
"program": "${workspaceFolder}/pypef/main.py",
341+
"console": "integratedTerminal",
342+
"justMyCode": true,
343+
"cwd": "${workspaceFolder}/datasets/AVGFP/",
344+
"args": [
345+
"hybrid",
346+
"directevo",
347+
"-m", "HYBRIDgremlinesm",
348+
"--wt", "P42212_F64L.fasta",
349+
"--params", "GREMLIN"
350+
]
351+
},
352+
285353
{ // PLMC zero-shot steps:
286354
// 1. $pypef param_inference --params uref100_avgfp_jhmmer_119_plmc_42.6.params
287355
// 2. $pypef hybrid -t TS.fasl --params PLMC

pypef/hybrid/hybrid_model.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
alphas: np.ndarray | None = None,
8080
parameter_range: list[tuple] | None = None,
8181
batch_size: int | None = None,
82+
llm_train: bool = True,
8283
device: str | None = None,
8384
seed: int | None = None
8485
):
@@ -135,6 +136,7 @@ def __init__(
135136
if batch_size is None:
136137
batch_size = 5
137138
self.batch_size = batch_size
139+
self.llm_train = llm_train
138140
(
139141
self.ridge_opt,
140142
self.beta1,
@@ -408,7 +410,7 @@ def train_llm(self):
408410
input_ids=self.input_ids,
409411
attention_mask=self.llm_attention_mask,
410412
structure_input_ids=self.structure_input_ids,
411-
train=True,
413+
train=False,
412414
device=self.device
413415
)
414416
y_llm_ttrain = self.llm_inference_function(
@@ -417,7 +419,7 @@ def train_llm(self):
417419
input_ids=self.input_ids,
418420
attention_mask=self.llm_attention_mask,
419421
structure_input_ids=self.structure_input_ids,
420-
train=True,
422+
train=False,
421423
device=self.device
422424
)
423425
elif self.llm_key == 'esm1v':
@@ -1206,7 +1208,11 @@ def performance_ls_ts(
12061208
print(f'Hybrid performance: {spearmanr(y_test, y_test_pred)}')
12071209
save_model_to_dict_pickle(hybrid_model, model_name)
12081210

1209-
elif ts_fasta is not None and model_pickle_file is not None and params_file is not None:
1211+
elif (
1212+
ts_fasta is not None and
1213+
model_pickle_file is not None
1214+
and params_file is not None
1215+
):
12101216
# no LS provided --> predicting with the saved model loaded from the Pickle file
12111217
print(f'Taking model from saved model (Pickle file): {model_pickle_file}...')
12121218
model, model_type = get_model_and_type(model_pickle_file)
@@ -1233,8 +1239,9 @@ def performance_ls_ts(
12331239
model.hybrid_prediction(x_test, x_llm_test)
12341240
else:
12351241
y_test_pred = model.hybrid_prediction(x_test)
1236-
1237-
elif ts_fasta is not None and model_pickle_file is None: # no LS provided --> statistical modeling / no ML
1242+
1243+
# no LS provided --> statistical modeling / no ML
1244+
elif ts_fasta is not None and model_pickle_file is None:
12381245
print(f"No learning set provided, falling back to statistical DCA model: "
12391246
f"no adjustments of individual hybrid model parameters (\"beta's\").")
12401247
test_sequences, test_variants, y_test = get_sequences_from_file(ts_fasta)
@@ -1354,7 +1361,8 @@ def predict_ps( # also predicting "pmult" dict directories
13541361
model, model_type = get_model_and_type(model_pickle_file)
13551362

13561363
if model_type == 'PLMC' or model_type == 'GREMLIN':
1357-
print(f'Found {model_type} model file. No hybrid model provided - falling back to a statistical DCA model...')
1364+
print(f'Found {model_type} model file. No hybrid model provided - '
1365+
f'falling back to a statistical DCA model...')
13581366

13591367
pmult = [
13601368
'Recomb_Double_Split', 'Recomb_Triple_Split', 'Recomb_Quadruple_Split',
@@ -1377,14 +1385,14 @@ def predict_ps( # also predicting "pmult" dict directories
13771385
substitution_sep=separator)
13781386
ys_pred = get_delta_e_statistical_model(x_test, x_wt)
13791387
else: # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1380-
x_test, _test_variants, *_ = plmc_or_gremlin_encoding(
1388+
x_test, _test_variants, test_sequences, *_ = plmc_or_gremlin_encoding(
13811389
variants, sequences, None, params_file,
13821390
threads=threads, verbose=False, substitution_sep=separator
13831391
)
13841392
if model.llm_key is None:
13851393
ys_pred = model.hybrid_prediction(x_test)
13861394
else:
1387-
sequences = [str(seq) for seq in sequences]
1395+
sequences = [str(seq) for seq in test_sequences]
13881396
x_llm_test = llm_embedder(model.llm_model_input, sequences)
13891397
ys_pred = model.hybrid_prediction(np.asarray(x_test), np.asarray(x_llm_test))
13901398
for k, y in enumerate(ys_pred):
@@ -1404,6 +1412,7 @@ def predict_ps( # also predicting "pmult" dict directories
14041412

14051413
elif prediction_set is not None: # Predicting single FASTA file sequences
14061414
sequences, variants, _ = get_sequences_from_file(prediction_set)
1415+
print(len(sequences), len(variants))
14071416
# NaNs are already being removed by the called function
14081417
if model_type != 'Hybrid': # statistical DCA model
14091418
xs, variants, _, _, x_wt, *_ = plmc_or_gremlin_encoding(
@@ -1412,13 +1421,16 @@ def predict_ps( # also predicting "pmult" dict directories
14121421
)
14131422
ys_pred = get_delta_e_statistical_model(xs, x_wt)
14141423
else: # Hybrid model input requires params from plmc or GREMLIN model plus optional LLM input
1415-
xs, variants, *_ = plmc_or_gremlin_encoding(
1424+
print(len(variants))
1425+
xs, variants, sequences, *_ = plmc_or_gremlin_encoding(
14161426
variants, sequences, None, params_file,
14171427
threads=threads, verbose=True, substitution_sep=separator
14181428
)
1429+
print('xs len', len(xs), len(variants))
14191430
if model.llm_key is None:
14201431
ys_pred = model.hybrid_prediction(xs)
14211432
else:
1433+
sequences = [str(seq) for seq in sequences]
14221434
xs_llm = llm_embedder(model.llm_model_input, sequences)
14231435
ys_pred = model.hybrid_prediction(np.asarray(xs), np.asarray(xs_llm))
14241436
assert len(xs) == len(variants) == len(xs_llm) == len(ys_pred)
@@ -1434,7 +1446,7 @@ def predict_ps( # also predicting "pmult" dict directories
14341446
def predict_directed_evolution(
14351447
encoder: str,
14361448
variant: str,
1437-
sequence: str,
1449+
variant_sequence: str,
14381450
hybrid_model_data_pkl: str
14391451
) -> Union[str, list]:
14401452
"""
@@ -1452,27 +1464,36 @@ def predict_directed_evolution(
14521464

14531465
if model_type != 'Hybrid': # statistical DCA model
14541466
xs, variant, _, _, x_wt, *_ = plmc_or_gremlin_encoding(
1455-
variant, sequence, None, encoder, verbose=False, use_global_model=True)
1467+
variant, variant_sequence, None, encoder,
1468+
verbose=False, use_global_model=True)
14561469
if not list(xs):
14571470
return 'skip'
14581471
y_pred = get_delta_e_statistical_model(xs, x_wt)
1459-
else: # model_type == 'Hybrid': Hybrid model input requires params from PLMC or GREMLIN model plus optional LLM input
1460-
xs, variant, *_ = plmc_or_gremlin_encoding(
1461-
variant, sequence, None, encoder, verbose=False, use_global_model=True
1472+
else: # model_type == 'Hybrid': Hybrid model input requires params
1473+
# from PLMC or GREMLIN model plus optional LLM input
1474+
print(variant, variant_sequence)
1475+
xs, variant, variant_sequence, *_ = plmc_or_gremlin_encoding(
1476+
variant, variant_sequence, None, encoder,
1477+
verbose=False, use_global_model=True
14621478
)
1479+
print(variant_sequence)
14631480
if not list(xs):
14641481
return 'skip'
14651482
if model.llm_model_input is None:
14661483
x_llm = None
14671484
else:
1468-
x_llm = llm_embedder(model.llm_model_input, sequence)
1485+
x_llm = llm_embedder(model.llm_model_input, variant_sequence)
14691486
try:
1487+
print(np.shape(xs), np.shape(x_llm), np.atleast_2d(x_llm))
1488+
#exit()
14701489
y_pred = model.hybrid_prediction(np.atleast_2d(xs), np.atleast_2d(x_llm))[0]
1471-
except ValueError:
1472-
raise SystemError(
1473-
"Probably a different model was used for encoding than for modeling; "
1474-
"e.g. using a HYBRIDgremlin model in combination with parameters taken from a PLMC file."
1475-
)
1490+
except ValueError as e:
1491+
raise e # TODO: Check sequences / mutations
1492+
# raise SystemError(
1493+
# "Probably a different model was used for encoding than "
1494+
# "for modeling; e.g. using a HYBRIDgremlin model in "
1495+
# "combination with parameters taken from a PLMC file."
1496+
# )
14761497
y_pred = float(y_pred)
14771498

14781499
return [(y_pred, variant[0][1:])]

pypef/llm/prosst_lora_tune.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
def prosst_tokenize_sequences(sequences, vocab):
3131
sequences = np.atleast_1d(sequences).tolist()
3232
x_sequences = []
33-
for sequence in tqdm(sequences, desc='Tokenizing sequences for PRoSST modeling'):
33+
for sequence in tqdm(sequences, desc='Tokenizing sequences for ProSST modeling'):
3434
x_sequence = []
3535
for aa in sequence:
3636
x_sequence.append(vocab[aa])
@@ -80,11 +80,16 @@ def get_logits_from_full_seqs(
8080
if i_aa == 0:
8181
seq_log_probs = logits[i_aa, x_aa].reshape(1)
8282
else:
83-
seq_log_probs = torch.cat((seq_log_probs, logits[i_aa, x_aa].reshape(1)), 0)
83+
seq_log_probs = torch.cat(
84+
(seq_log_probs, logits[i_aa, x_aa].reshape(1)), 0)
8485
if i_s == 0:
8586
log_probs = torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
8687
else:
87-
log_probs = torch.cat((log_probs, torch.sum(torch.Tensor(seq_log_probs)).reshape(1)), 0)
88+
log_probs = torch.cat((
89+
log_probs,
90+
torch.sum(torch.Tensor(seq_log_probs)).reshape(1)
91+
), 0
92+
)
8893
return log_probs
8994

9095

@@ -104,8 +109,13 @@ def prosst_train(
104109
if seed is not None:
105110
torch.manual_seed(seed)
106111
if device is None:
107-
device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
108-
print(f'ProSST training using {device.upper()} device (N_Train={len(torch.flatten(score_batches))})...')
112+
device = (
113+
"cuda" if torch.cuda.is_available()
114+
else "mps" if torch.backends.mps.is_available()
115+
else "cpu"
116+
)
117+
print(f"ProSST training using {device.upper()} device "
118+
f"(N_Train={len(torch.flatten(score_batches))})...")
109119
x_sequence_batches = x_sequence_batches.to(device)
110120
score_batches = score_batches.to(device)
111121
pbar_epochs = tqdm(range(1, n_epochs + 1))
@@ -119,9 +129,13 @@ def prosst_train(
119129
pbar_epochs.set_description(f'Epoch {epoch}/{n_epochs}')
120130
model.train()
121131
y_preds_detached = []
122-
pbar_batches = tqdm(zip(x_sequence_batches, score_batches), total=len(x_sequence_batches), leave=False)
132+
pbar_batches = tqdm(zip(x_sequence_batches, score_batches),
133+
total=len(x_sequence_batches), leave=False)
123134
for batch, (seqs_b, scores_b) in enumerate(pbar_batches):
124-
y_preds_b = get_logits_from_full_seqs(seqs_b, model, input_ids, attention_mask, structure_input_ids, train=True, verbose=False)
135+
y_preds_b = get_logits_from_full_seqs(
136+
seqs_b, model, input_ids, attention_mask, structure_input_ids,
137+
train=True, verbose=False
138+
)
125139
y_preds_detached.append(y_preds_b.detach().cpu().numpy().flatten())
126140
loss = loss_fn(scores_b, y_preds_b)
127141
loss.backward()
@@ -132,7 +146,8 @@ def prosst_train(
132146
f"[batch: {batch+1}/{len(x_sequence_batches)} | "
133147
f"sequence: {(batch + 1) * len(seqs_b):>5d}/{len(x_sequence_batches) * len(seqs_b)}] "
134148
)
135-
epoch_spearman_2 = spearmanr(score_batches.cpu().numpy().flatten(), np.array(y_preds_detached).flatten())[0]
149+
epoch_spearman_2 = spearmanr(score_batches.cpu().numpy().flatten(),
150+
np.array(y_preds_detached).flatten())[0]
136151
if epoch_spearman_2 == np.nan:
137152
raise SystemError(
138153
f"No correlation between Y_true and Y_pred could be computed...\n"
@@ -143,7 +158,10 @@ def prosst_train(
143158
did_not_improve_counter = 0
144159
best_model_epoch = epoch
145160
best_model_perf = epoch_spearman_2
146-
best_model = f"model_saves/Epoch{epoch}-Ntrain{len(score_batches.cpu().numpy().flatten())}-SpearCorr{epoch_spearman_2:.3f}.pt"
161+
best_model = (
162+
f"model_saves/Epoch{epoch}-Ntrain{len(score_batches.cpu().numpy().flatten())}"
163+
f"-SpearCorr{epoch_spearman_2:.3f}.pt"
164+
)
147165
checkpoint(model, best_model)
148166
epoch_spearman_1 = epoch_spearman_2
149167
#print(f"Saved current best model as {best_model}")
@@ -167,13 +185,16 @@ def prosst_train(
167185
y_preds_train = get_logits_from_full_seqs(
168186
x_sequence_batches.flatten(start_dim=0, end_dim=1),
169187
model, input_ids, attention_mask, structure_input_ids, train=False, verbose=False)
170-
print(f'Train-->Train Performance (N={len(score_batches.cpu().flatten())}):', spearmanr(score_batches.cpu().flatten(), y_preds_train.cpu()))
188+
print(f'Train-->Train Performance (N={len(score_batches.cpu().flatten())}):',
189+
spearmanr(score_batches.cpu().flatten(), y_preds_train.cpu()))
171190
return y_preds_train.cpu()
172191

173192

174193
def get_prosst_models():
175-
prosst_base_model = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
176-
tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
194+
prosst_base_model = AutoModelForMaskedLM.from_pretrained(
195+
"AI4Protein/ProSST-2048", trust_remote_code=True)
196+
tokenizer = AutoTokenizer.from_pretrained(
197+
"AI4Protein/ProSST-2048", trust_remote_code=True)
177198
peft_config = LoraConfig(r=8, target_modules=["query", "value"])
178199
prosst_lora_model = get_peft_model(prosst_base_model, peft_config)
179200
# TODO: Check whether LoRA or base model parameters (and which learning rate) work better for ProSST fine-tuning
@@ -187,7 +208,8 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
187208
tokenized_res = tokenizer([wt_seq], return_tensors='pt')
188209
input_ids = tokenized_res['input_ids']
189210
attention_mask = tokenized_res['attention_mask']
190-
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2], dtype=torch.long).unsqueeze(0)
211+
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2],
212+
dtype=torch.long).unsqueeze(0)
191213
return input_ids, attention_mask, structure_input_ids
192214

193215

@@ -214,8 +236,10 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
214236
prosst_vocab = prosst_tokenizer.get_vocab()
215237
prosst_base_model = prosst_base_model.to(device)
216238
prosst_optimizer = torch.optim.Adam(prosst_lora_model.parameters(), lr=0.0001)
217-
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(pdb_file, prosst_tokenizer, wt_seq)
218-
x_llm_train_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
239+
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(
240+
pdb_file, prosst_tokenizer, wt_seq)
241+
x_llm_train_prosst = prosst_tokenize_sequences(
242+
sequences=sequences, vocab=prosst_vocab)
219243
llm_dict_prosst = {
220244
'prosst': {
221245
'llm_base_model': prosst_base_model,

pypef/utils/directed_evolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def in_silico_de(self):
251251
predictions = predict_directed_evolution(
252252
encoder=self.dca_encoder,
253253
variant=self.s_wt[int(new_variant[:-1]) - 1] + new_variant,
254-
sequence=new_sequence,
254+
variant_sequence=new_sequence,
255255
hybrid_model_data_pkl=self.model
256256
)
257257
if predictions != 'skip':

0 commit comments

Comments
 (0)