Skip to content

Commit 21e605f

Browse files
committed
Improved training progress indication and stopping functionalities
1 parent 492032d commit 21e605f

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

pypef/gui/qt_window.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,23 @@
4747
}"""
4848

4949

50+
progress_style = """
51+
QProgressBar {
52+
border: 1px solid #444;
53+
border-radius: 6px;
54+
background-color: #2b2b2b;
55+
text-align: center;
56+
height: 14px;
57+
}
58+
59+
QProgressBar::chunk {
60+
background-color: #3daee9;
61+
border-radius: 6px;
62+
}
63+
"""
64+
65+
66+
5067
class QTextEditLogger(logging.Handler, QObject):
5168
"""
5269
Thread-safe logging handler for PyQt/PySide applications.
@@ -100,10 +117,6 @@ def __init__(self, id_: int, cmd):
100117
self.__id = id_
101118
self.cmd = cmd
102119
self._abort = False
103-
104-
def abort(self):
105-
self._abort = True
106-
self.sig_msg.emit(f'Worker #{self.__id} abort requested')
107120

108121
@Slot()
109122
def work(self):
@@ -139,6 +152,7 @@ def abort_cb():
139152
self.sig_done.emit(f"Done: {self.__id}")
140153

141154
def abort(self):
155+
self._abort = True
142156
self.sig_msg.emit(f'Worker #{self.__id} notified to abort')
143157

144158

@@ -270,11 +284,11 @@ def __init__(self):
270284

271285
self.epoch_progress_bar = QProgressBar()
272286
self.epoch_progress_bar.setTextVisible(False)
273-
#self.epoch_progress_bar.setFormat("Epoch %v / %m (%p%) | Elapsed: 00:00 | ETA: --:--")
287+
self.epoch_progress_bar.setStyleSheet(progress_style)
274288

275289
self.batch_progress_bar = QProgressBar()
276290
self.batch_progress_bar.setTextVisible(False)
277-
#self.batch_progress_bar.setFormat("Batch %v / %m (%p%) | Elapsed: 00:00 | ETA: --:--")
291+
self.batch_progress_bar.setStyleSheet(progress_style)
278292

279293
# ComboBoxes ####################################################################
280294
self.box_regression_model = QComboBox()
@@ -305,6 +319,10 @@ def __init__(self):
305319

306320
# Buttons #######################################################################
307321
# Utilities
322+
self.button_abort = QPushButton("Stop training")
323+
self.button_abort.clicked.connect(self.abort_workers)
324+
self.button_abort.setStyleSheet(button_style)
325+
308326
self.button_work_dir = QPushButton("Set Working Directory")
309327
self.button_work_dir.setToolTip(
310328
"Set working directory for storing output files"
@@ -626,6 +644,8 @@ def __init__(self):
626644
layout.addWidget(self.button_work_dir, 0, 2, 1, 1)
627645
layout.addWidget(self.working_directory_text, 0, 3, 1, 1)
628646

647+
layout.addWidget(self.button_abort, 3, 5, 1, 1)
648+
629649
layout.addWidget(self.utils_text, self.shift + 3, 0, 1, 1)
630650
layout.addWidget(self.button_help, self.shift + 4, 0, 1, 1)
631651
layout.addWidget(self.button_mklsts, self.shift + 5, 0, 1, 1)
@@ -698,9 +718,12 @@ def start_main_thread(self):
698718
self.__threads.append((thread, worker))
699719
worker.moveToThread(thread)
700720

701-
worker.sig_step.connect(self.on_progress_step)
721+
worker.sig_step.connect(self.on_train_progress_step)
702722

703723
worker.sig_done.connect(self.on_worker_done)
724+
worker.sig_done.connect(thread.quit)
725+
worker.sig_done.connect(worker.deleteLater)
726+
worker.sig_done.connect(thread.deleteLater)
704727
worker.sig_msg.connect(self.logTextBox.widget.appendPlainText)
705728

706729
thread.started.connect(worker.work)
@@ -729,10 +752,11 @@ def handle_info_tick(self, info_text: str):
729752
self.device_text_out.setPlainText(new_info)
730753

731754
@Slot(dict)
732-
def on_progress_step(self, progress):
755+
def on_train_progress_step(self, progress):
733756
if self._train_start_time is None:
734757
self._train_start_time = time.time()
735758
self._last_epoch = 1
759+
self._last_epoch_time = self._train_start_time
736760
self.epoch_eta = "--:--"
737761
self.elapsed = 0
738762

@@ -749,7 +773,6 @@ def on_progress_step(self, progress):
749773

750774
if now - self._last_eta_update < 0.3:
751775
return
752-
self._last_eta_update = now
753776

754777
# Epoch ETA
755778
if self._last_epoch != progress['epoch']:
@@ -759,28 +782,32 @@ def on_progress_step(self, progress):
759782
progress['epoch_total']
760783
)
761784
self._last_epoch = progress['epoch']
785+
self._last_epoch_time = time.time()
762786

763787
# Batch ETA
764-
# TODO: Add Batch elapsed reset
788+
elapsed_since_last_epoch = now - self._last_epoch_time
765789
self.batch_eta = self.estimate_eta(
766-
self.elapsed,
790+
elapsed_since_last_epoch,
767791
progress['batch'],
768792
progress['batch_total']
769793
)
770794

771795
elapsed_str = self.format_time(self.elapsed)
796+
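# The batch label is refreshed on every update. NOTE(review): delta_elapsed_str is only assigned when this is not the final epoch, yet it is read unconditionally below -- confirm.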
797+
if not progress['epoch'] == progress['epoch_total']:
798+
delta_elapsed_str = self.format_time(elapsed_since_last_epoch)
772799

773800
# Update format text (stable width!)
774801
self.epoch_time_label.setText(
775-
f"Batch {progress['epoch']:04d} / {progress['epoch_total']:04d} "
802+
f"Epoch {progress['epoch']} / {progress['epoch_total']} "
776803
f"({int((progress['epoch'] / progress['epoch_total']) * 100)}%) "
777804
f"| Elapsed: {elapsed_str} | ETA: {self.epoch_eta}"
778805
)
779806

780807
self.batch_time_label.setText(
781-
f"Batch {progress['batch']:04d} / {progress['batch_total']:04d} "
808+
f"Batch {progress['batch']} / {progress['batch_total']} "
782809
f"({int((progress['batch'] / progress['batch_total']) * 100)}%) "
783-
f"| Elapsed: {elapsed_str} | ETA: {self.batch_eta}"
810+
f"| Elapsed: {delta_elapsed_str} | ETA: {self.batch_eta}"
784811
)
785812

786813
@Slot(int)
@@ -798,12 +825,15 @@ def abort_workers(self):
798825
# are running in a single QThread without getting callbacks from
799826
# a computing loop or so. So no qthreaded job abortions possible
800827
# without using QThread::terminate(), which should not be used.
828+
# TODO: Extend the abort support to the other Signal-connected
829+
# processing jobs (so far it is only implemented for training).
801830
self.logTextBox.widget.appendPlainText(
802831
'Asking each worker to abort...'
803832
)
804833
for thread, worker in self.__threads:
805-
thread.quit()
806-
thread.wait()
834+
#thread.quit()
835+
#thread.wait()
836+
worker.abort()
807837
# the threads may still be running here (quit()/wait() are no longer
808838
# called above); messages that they emitted before the abort may
809839
# still be on the main thread's queue:
@@ -833,6 +863,7 @@ def end_process(self):
833863
self.toggle_buttons(True)
834864
self.epoch_progress_bar.setValue(0)
835865
self.batch_progress_bar.setValue(0)
866+
self._train_start_time = None
836867
self.textedit_out.append("=" * 60 + "\n")
837868
self.version_text.setText("Finished...")
838869

pypef/plm/esm_lora_tune.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ def esm_train(
158158
xs, attention_masks, scores = xs.to(device), attention_masks.to(device), scores.to(device)
159159
pbar_epochs = tqdm(range(1, n_epochs + 1), disable=not verbose)
160160
loss = np.nan
161-
logger.info(progress_cb) # TODO: delete
162161
for epoch in pbar_epochs:
163162
try:
164163
pbar_epochs.set_description(f'Epoch: {epoch}/{n_epochs}. Loss: {loss.detach():>1f}')
@@ -170,6 +169,8 @@ def esm_train(
170169
total=len(xs), leave=False, disable=not verbose
171170
)
172171
for batch, (xs_b, attns_b, scores_b) in enumerate(pbar_batches):
172+
if abort_cb and abort_cb():
173+
return
173174
xs_b, attns_b = xs_b.to(torch.int64), attns_b.to(torch.int64)
174175
y_preds_b = get_y_pred_scores(xs_b, attns_b, model, device=device)
175176
loss = loss_fn(scores_b, y_preds_b) / n_batch_grad_accumulations

pypef/plm/prosst_lora_tune.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def prosst_train(
168168
total=len(x_sequence_batches), leave=False, disable=not verbose
169169
)
170170
for batch, (seqs_b, scores_b) in enumerate(pbar_batches):
171+
if abort_cb and abort_cb():
172+
return
171173
y_preds_b = get_logits_from_full_seqs(
172174
seqs_b, model, input_ids, attention_mask, structure_input_ids,
173175
train=True, verbose=False

0 commit comments

Comments
 (0)