
Commit 492032d

First QProgressBars
1 parent 8df142c commit 492032d

25 files changed (+272 −122 lines)

pypef/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
 # PyPEF - Pythonic Protein Engineering Framework
 # https://github.com/niklases/PyPEF
 
-__version__ = '0.4.3'
+__version__ = '0.4.4-dev'
```

pypef/dca/gremlin_inference.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,7 +49,7 @@
 from tqdm import tqdm
 import torch
 
-from pypef.llm.utils import get_batches
+from pypef.plm.utils import get_batches
 from pypef.utils.variant_data import get_mismatches
 
```

pypef/gui/qt_window.py

Lines changed: 201 additions & 80 deletions
Large diffs are not rendered by default.
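Since the qt_window.py diff is not rendered here, the following is only a plausible sketch of the GUI-side wiring the commit title suggests: a background worker forwards progress_cb calls to a QProgressBar through a Qt signal, because widgets must only be updated from the GUI thread. The PySide6 import, the class name, and the percentage mapping are all assumptions, not the commit's actual code.

```python
from PySide6.QtCore import QObject, Signal

class PypefWorker(QObject):  # hypothetical worker class, not from the diff
    progress = Signal(int)   # overall progress in percent

    def run(self, argv):
        from pypef.main import run_main

        def progress_cb(epoch, batch, n_epochs, n_batches, loss):
            # Collapse (epoch, batch) into a single 0-100 value for the bar.
            done = epoch * n_batches + batch
            self.progress.emit(min(100, int(100 * done / (n_epochs * n_batches))))

        run_main(argv, progress_cb=progress_cb)

# GUI side (illustrative): worker.progress.connect(progress_bar.setValue)
```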

pypef/hybrid/hybrid_model.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -37,10 +37,10 @@
 from pypef.utils.plot import plot_y_true_vs_y_pred
 import pypef.dca.gremlin_inference
 from pypef.dca.gremlin_inference import GREMLIN, get_delta_e_statistical_model
-from pypef.llm.esm_lora_tune import esm_setup, get_esm_models
-from pypef.llm.prosst_lora_tune import get_prosst_models, prosst_setup
-from pypef.llm.inference import llm_embedder, inference
-from pypef.llm.utils import get_batches
+from pypef.plm.esm_lora_tune import esm_setup, get_esm_models
+from pypef.plm.prosst_lora_tune import get_prosst_models, prosst_setup
+from pypef.plm.inference import llm_embedder, inference
+from pypef.plm.utils import get_batches
 
 # sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and
 # will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function
@@ -70,7 +70,9 @@ def __init__(
         llm_train: bool = True,
         device: str | None = None,
         seed: int | None = None,
-        verbose: bool = True
+        verbose: bool = True,
+        progress_cb=None,
+        abort_cb=None
     ):
         if llm_model_input is not None:
             if type(llm_model_input) is not dict:
@@ -141,6 +143,8 @@ def __init__(
             self.y_llm_ttest,
             self.y_llm_lora_ttest
         ) = None, None, None, None, None, None, None, None, None
+        self.progress_cb = progress_cb
+        self.abort_cb = abort_cb
         self.train_and_optimize()
 
     @staticmethod
@@ -465,7 +469,9 @@ def train_llm(self):
                 n_epochs=50,
                 device=self.device,
                 verbose=self.verbose,
-                raise_error_on_train_fail=False
+                raise_error_on_train_fail=False,
+                progress_cb=self.progress_cb,
+                abort_cb=self.abort_cb
             )
             y_llm_lora_ttrain = self.llm_inference_function(
                 xs=self.x_llm_ttrain,
@@ -496,7 +502,9 @@ def train_llm(self):
                 self.llm_optimizer,
                 n_epochs=5,
                 device=self.device,
-                verbose=self.verbose
+                verbose=self.verbose,
+                progress_cb=self.progress_cb,
+                abort_cb=self.abort_cb
             )
             y_llm_lora_ttrain = self.llm_inference_function(
                 xs=x_llm_ttrain_b,
@@ -802,6 +810,8 @@ def save_model_to_dict_pickle(
         model.llm_base_model = model.llm_base_model.state_dict()
         model.llm_model_input[model.llm_key]['llm_base_model'] = None
         model.llm_model_input[model.llm_key]['llm_model'] = None
+        model.progress_cb = None
+        model.abort_cb = None
         model_type += model.llm_key.upper()
     pkl_path = os.path.abspath(f'Pickles/{model_type.upper()}')
     pickle.dump(
@@ -986,7 +996,9 @@ def performance_ls_ts(
         wt_seq: str | None = None,
         substitution_sep: str = '/',
         label=False,
-        device: str | None = None
+        device: str | None = None,
+        progress_cb=None,
+        abort_cb=None
 ):
     test_sequences, test_variants, y_test = get_sequences_from_file(ts_fasta)
 
@@ -1032,7 +1044,9 @@ def performance_ls_ts(
             y_train=np.array(y_train),
             llm_model_input=llm_dict,
             x_wt=x_wt,
-            device=device
+            device=device,
+            progress_cb=progress_cb,
+            abort_cb=abort_cb
         )
         y_test_pred = hybrid_model.hybrid_prediction(np.array(x_test), x_llm_test)
         logger.info(f'Hybrid performance: {spearmanr(y_test, y_test_pred)[0]:.3f} N={len(y_test)}')
```
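The hunks above thread two optional callbacks, progress_cb and abort_cb, from performance_ls_ts down through the hybrid model into the LLM training calls. A minimal sketch of a compatible receiver, assuming only the positional convention visible in the esm_train/prosst_train call sites further below, i.e. progress_cb(epoch, batch, n_epochs, n_batches, loss); the parameter names are illustrative, only the positions are fixed by the call sites:

```python
def console_progress_cb(epoch, batch, n_epochs, n_batches, loss):
    # loss may be a 0-dim torch tensor; float() converts either way.
    done = epoch * n_batches + batch
    total = n_epochs * n_batches
    print(f"\r{100 * done / total:5.1f}% | loss: {float(loss):.4f}", end="", flush=True)
```

Such a function can then be handed to performance_ls_ts via its new progress_cb keyword.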

pypef/hybrid/hybrid_run.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -15,7 +15,7 @@
 from pypef.utils.low_n_mutation_extrapolation import performance_mutation_extrapolation, low_n
 
 
-def run_pypef_hybrid_modeling(arguments):
+def run_pypef_hybrid_modeling(arguments, progress_cb=None, abort_cb=None):
     threads = abs(arguments['--threads']) if arguments['--threads'] is not None else 1
     threads = threads + 1 if threads == 0 else threads
     if arguments['--params'] is not None:
@@ -52,7 +52,9 @@ def run_pypef_hybrid_modeling(arguments):
             pdb_file=arguments['--pdb'],
             wt_seq=get_wt_sequence(arguments['--wt']),
             substitution_sep=arguments['--mutation_sep'],
-            label=arguments['--label']
+            label=arguments['--label'],
+            progress_cb=progress_cb,
+            abort_cb=abort_cb
         )
 
     elif arguments['--params'] and arguments['--model'] or arguments['--ps']:
```

pypef/main.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -417,7 +417,7 @@ def validate(args):
         exit(e)
 
 
-def run_main(argv=None):
+def run_main(argv=None, progress_cb=None, abort_cb=None):
     """
     Entry point for pip-installed version.
     Arguments are created from Docstring using docopt that
@@ -434,7 +434,7 @@ def run_main(argv=None):
     elif arguments['ml']:
         run_pypef_pure_ml(arguments)
     elif arguments['hybrid'] or arguments['param_inference'] or arguments['save_msa_info']:
-        run_pypef_hybrid_modeling(arguments, progress_cb=progress_cb, abort_cb=abort_cb)
+        run_pypef_hybrid_modeling(arguments, progress_cb=progress_cb, abort_cb=abort_cb)
     else:
         run_pypef_utils(arguments)
```
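With run_main extended, a non-GUI caller can pass the same callbacks, for instance for plain console logging. A hedged usage sketch: run_main with argv=None falls back to the real command line via docopt, and the abort_cb convention is an assumption, since the rendered hunks pass it through but never invoke it:

```python
from pypef.main import run_main

def progress_cb(epoch, batch, n_epochs, n_batches, loss):
    print(f"epoch {epoch}/{n_epochs} | batch {batch}/{n_batches}")

def abort_cb():
    # Assumed convention: return True to request an early stop.
    return False

run_main(argv=None, progress_cb=progress_cb, abort_cb=abort_cb)
```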

pypef/plm/esm_lora_tune.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -30,7 +30,7 @@
 hf_logging.set_verbosity_error()
 
 from pypef.utils.helpers import get_device
-from pypef.llm.utils import corr_loss, load_model_and_tokenizer
+from pypef.plm.utils import corr_loss, load_model_and_tokenizer
 
 
 def get_esm_models():
@@ -143,7 +143,8 @@ def esm_infer(xs, attention_mask, model, device: str | None = None, verbose=False):
 def esm_train(
     xs, attention_mask, scores, loss_fn, model, optimizer, n_epochs=3,
     device: str | None = None, seed: int | None = None,
-    n_batch_grad_accumulations: int = 1, verbose: bool = True
+    n_batch_grad_accumulations: int = 1, verbose: bool = True,
+    progress_cb=None, abort_cb=None
 ):
     if seed is not None:
         torch.manual_seed(seed)
@@ -157,6 +158,7 @@ def esm_train(
     xs, attention_masks, scores = xs.to(device), attention_masks.to(device), scores.to(device)
     pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
     loss = np.nan
+    logger.info(progress_cb)  # TODO: delete
     for epoch in pbar_epochs:
         try:
             pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}. Loss: {loss.detach():>1f}')
@@ -171,6 +173,8 @@ def esm_train(
             xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
             y_preds_b = get_y_pred_scores(xs_b, attns_b, model, device=device)
             loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
+            if progress_cb:
+                progress_cb(epoch - 1, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
             loss.backward()
             if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == len(pbar_batches):
                 optimizer.step()
@@ -180,6 +184,8 @@ def esm_train(
                 f"[batch: {batch+1}/{len(xs)} | sequence: "
                 f"{(batch + 1) * len(xs_b):>5d}/{len(xs) * len(xs_b)}] ({device.upper()})"
             )
+            if progress_cb:
+                progress_cb(epoch, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
             y_preds_b = y_preds_b.detach()
     model.train(False)
```
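Note the two call sites: the mid-batch call passes the zero-based epoch (epoch - 1) while the post-update call passes the one-based epoch, so a robust receiver should clamp its percentage. A small sketch of that mapping (the function name is illustrative):

```python
def to_percent(epoch, batch, n_epochs, n_batches, loss=None):
    # epoch * n_batches + batch counts completed batch updates; the final
    # call can pass the one-based epoch, so clamp at 100.
    done = epoch * n_batches + batch
    return min(100, int(100 * done / (n_epochs * n_batches)))

assert to_percent(0, 1, 50, 10) == 0      # first batch of the first epoch
assert to_percent(50, 10, 50, 10) == 100  # clamped final call
```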

pypef/plm/inference.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -7,9 +7,9 @@
 import numpy as np
 
 from pypef.utils.helpers import get_device
-from pypef.llm.utils import get_batches
-from pypef.llm.esm_lora_tune import esm_setup, esm_tokenize_sequences, esm_infer
-from pypef.llm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences, prosst_infer
+from pypef.plm.utils import get_batches
+from pypef.plm.esm_lora_tune import esm_setup, esm_tokenize_sequences, esm_infer
+from pypef.plm.prosst_lora_tune import prosst_setup, prosst_tokenize_sequences, prosst_infer
 
 import logging
 logger = logging.getLogger('pypef.llm.inference')
```
pypef/plm/prosst_lora_tune.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -11,7 +11,7 @@
 
 import logging
 
-from pypef.llm.utils import load_model_and_tokenizer
+from pypef.plm.utils import load_model_and_tokenizer
 logger = logging.getLogger('pypef.llm.prosst_lora_tune')
 
 import os
@@ -25,8 +25,8 @@
 from Bio import SeqIO, BiopythonParserWarning
 warnings.filterwarnings(action='ignore', category=BiopythonParserWarning)
 
-from pypef.llm.esm_lora_tune import corr_loss
-from pypef.llm.prosst_structure.quantizer import PdbQuantizer
+from pypef.plm.esm_lora_tune import corr_loss
+from pypef.plm.prosst_structure.quantizer import PdbQuantizer
 from pypef.utils.helpers import get_device
 
 
@@ -139,7 +139,9 @@ def prosst_train(
     input_ids, attention_mask, structure_input_ids,
     n_epochs=50, device: str | None = None, seed: int | None = None,
     early_stop: int = 50, verbose: bool = True,
-    n_batch_grad_accumulations: int = 1, raise_error_on_train_fail: bool = True):
+    n_batch_grad_accumulations: int = 1, raise_error_on_train_fail: bool = True,
+    progress_cb=None, abort_cb=None
+):
     if seed is not None:
         torch.manual_seed(seed)
     if device is None:
@@ -154,6 +156,7 @@ def prosst_train(
     best_model = None
     best_model_epoch = np.nan
     best_model_perf = np.nan
+    loss = np.nan
    os.makedirs('model_saves', exist_ok=True)
     for epoch in pbar_epochs:
         if epoch == 0:
@@ -171,6 +174,8 @@ def prosst_train(
             )
             y_preds_detached.append(y_preds_b.detach().cpu().numpy().flatten())
             loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations
+            if progress_cb:
+                progress_cb(epoch - 1, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
             loss.backward()
             if (batch + 1) % n_batch_grad_accumulations == 0 or (batch + 1) == len(pbar_batches):
                 optimizer.step()
@@ -215,6 +220,8 @@ def prosst_train(
         pbar_epochs.set_description(
             f'Epoch {epoch}/{n_epochs} [SpearCorr: {epoch_spearman_2:.3f}, Loss: {loss_total:.3f}] '
             f'(Best epoch: {best_model_epoch}: {best_model_perf:.3f})')
+        if progress_cb:
+            progress_cb(epoch, batch + 1, len(pbar_epochs), len(pbar_batches), loss)
     if best_model is None:
         msg = ("Failed to train a model (probably due to the input "
                "data characteristics and loss/correlation being NaN).")
```
