Skip to content

Commit f9f85cd

Browse files
committed
Add LLM zero-shot test option (backend and GUI QComboBox)
1 parent e94ee3c commit f9f85cd

File tree

4 files changed

+141
-107
lines changed

4 files changed

+141
-107
lines changed
Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# PyPEF - Pythonic Protein Engineering Framework
22
# https://github.com/niklases/PyPEF
33

4-
# PyPEF Qt GUI window using PySide6
4+
# Qt GUI window using PySide6
55

66
import sys
77
from os import getcwd, cpu_count, chdir
8+
import logging
89

910
from PySide6.QtCore import QObject, QThread, QSize, Qt, QRect, Signal, Slot
1011
from PySide6.QtWidgets import (
@@ -16,9 +17,6 @@
1617
from pypef.main import __doc__, run_main, logger, formatter
1718
from pypef.utils.helpers import get_device, get_vram, get_torch_version, get_gpu_info
1819

19-
import logging
20-
logger.setLevel(logging.INFO)
21-
2220

2321
button_style = """
2422
QPushButton {
@@ -72,7 +70,6 @@ def emit(self, record):
7270
self.log_signal.emit(msg)
7371

7472

75-
7673
def trap_exc_during_debug(*args):
7774
# When app raises uncaught exception, print info
7875
print(args)
@@ -123,10 +120,11 @@ def work(self):
123120
every trained epoch from the executed imported function.
124121
"""
125122
thread_name = QThread.currentThread().objectName()
126-
self.sig_msg.emit(
127-
f'Running worker #{self.__id} from thread "{thread_name}" '
128-
f'executing command: "{self.cmd}"'
129-
)
123+
#self.sig_msg.emit(
124+
# f'Running worker #{self.__id} from thread "{thread_name}" '
125+
# f'executing command: "{self.cmd}"'
126+
#)
127+
print(f"Executing command: {self.cmd}")
130128

131129
run_main(argv=self.cmd)
132130
self.sig_done.emit(f"Done: {self.__id}")
@@ -152,7 +150,7 @@ def __init__(self):
152150
self.mklsts_cv_method = ''
153151
self.c = 0
154152
self.ls_proportion = 0.8
155-
self.setMinimumSize(QSize(1400, 400))
153+
self.setMinimumSize(QSize(1400, 800))
156154
self.setWindowTitle("PyPEF GUI")
157155
self.setStyleSheet("background-color: rgb(40, 44, 52);")
158156
self.win2 = SecondWindow()
@@ -161,7 +159,7 @@ def __init__(self):
161159
self.__workers_done = None
162160
self.__threads = None
163161

164-
# Texts ########################################################################
162+
# Texts ###############################################################
165163
layout = QGridLayout(self) # MAIN LAYOUT: QGridLayout
166164
self.version_text = QLabel(f"PyPEF v. {__version__}", alignment=Qt.AlignRight)
167165
#self.ncores_text = QLabel("Single-/multiprocessing")
@@ -206,7 +204,7 @@ def __init__(self):
206204
self.logTextBox.widget.appendPlainText(
207205
f"Current working directory: {str(getcwd())}")
208206

209-
# Horizontal slider ############################################################
207+
# Horizontal slider ###################################################
210208
self.slider = QSlider(self)
211209
self.slider.setGeometry(QRect(190, 100, 200, 16))
212210
self.slider.setOrientation(Qt.Horizontal)
@@ -217,7 +215,7 @@ def __init__(self):
217215
self.slider.move(10, 105)
218216
self.slider.valueChanged.connect(self.selection_ls_proportion)
219217

220-
# ComboBoxes ########################################################################
218+
# ComboBoxes ##########################################################
221219
self.box_regression_model = QComboBox()
222220
self.regression_models = [
223221
'PLS', 'PLS_LOOCV', 'Ridge', 'Lasso', 'ElasticNet', 'SVR', 'RF', 'MLP'
@@ -231,8 +229,9 @@ def __init__(self):
231229
)
232230

233231
self.box_llm = QComboBox()
234-
self.box_llm.addItems(['ESM1v', 'ProSST'])
232+
self.box_llm.addItems(['None', 'ESM1v', 'ProSST'])
235233
self.box_llm.currentIndexChanged.connect(self.selection_llm_model)
234+
self.box_llm.setCurrentIndex(1)
236235
self.box_llm.setStyleSheet("color:white;background-color:rgb(54, 69, 79);")
237236

238237
self.box_mklsts_cv = QComboBox()
@@ -242,9 +241,8 @@ def __init__(self):
242241
])
243242
self.box_mklsts_cv.currentIndexChanged.connect(self.selection_mklsts_splits)
244243
self.box_mklsts_cv.setStyleSheet("color:white;background-color:rgb(54, 69, 79);")
245-
246244

247-
# Buttons ######################################################################
245+
# Buttons #############################################################
248246
# Utilities
249247
self.button_work_dir = QPushButton("Set Working Directory")
250248
self.button_work_dir.setToolTip(
@@ -315,6 +313,16 @@ def __init__(self):
315313
self.button_dca_predict_dca.clicked.connect(self.pypef_dca_predict)
316314
self.button_dca_predict_dca.setStyleSheet(button_style)
317315

316+
# Zero-shot LLM
317+
self.button_llm_test_zs = QPushButton("Test (LLM)")
318+
self.button_llm_test_zs.setMinimumWidth(80)
319+
self.button_llm_test_zs.setToolTip(
320+
"Test performance on any test dataset using "
321+
"the LLM model for zero-shot prediction"
322+
)
323+
self.button_llm_test_zs.clicked.connect(self.pypef_llm_test)
324+
self.button_llm_test_zs.setStyleSheet(button_style)
325+
318326
# Hybrid DCA
319327
self.button_hybrid_train_dca = QPushButton("Train (DCA)")
320328
self.button_hybrid_train_dca.setMinimumWidth(80)
@@ -544,7 +552,8 @@ def __init__(self):
544552
layout.addWidget(self.button_dca_inference_gremlin, 4, 1, 1, 1)
545553
layout.addWidget(self.button_dca_inference_gremlin_msa_info, 5, 1, 1, 1)
546554
layout.addWidget(self.button_dca_test_dca, 6, 1, 1, 1)
547-
layout.addWidget(self.button_dca_predict_dca, 7, 1, 1, 1)
555+
layout.addWidget(self.button_llm_test_zs, 7, 1, 1, 1)
556+
layout.addWidget(self.button_dca_predict_dca, 8, 1, 1, 1)
548557

549558
layout.addWidget(self.hybrid_text, 3, 2, 1, 1)
550559
layout.addWidget(self.button_hybrid_train_dca, 4, 2, 1, 1)
@@ -639,7 +648,9 @@ def end_process(self):
639648
self.toggle_buttons(True)
640649
self.textedit_out.append("=" * 60 + "\n")
641650
self.version_text.setText("Finished...")
642-
651+
652+
653+
# Box selections ##########################################################
643654
def selection_ncores(self, i):
644655
if i == 0:
645656
self.n_cores = 1
@@ -652,7 +663,7 @@ def selection_regression_model(self, i):
652663
][i]
653664

654665
def selection_llm_model(self, i):
    """Store the LLM backend that matches the selected combo-box index.

    The order mirrors the entries added to ``box_llm``
    (``'None'``, ``'ESM1v'``, ``'ProSST'``): index 0 stores ``None``,
    index 1 stores ``'esm'`` and index 2 stores ``'prosst'``.

    Args:
        i: Index delivered by the combo box's ``currentIndexChanged`` signal.
    """
    llm_options = (None, 'esm', 'prosst')
    # Qt documents -1 as the index emitted when a combo box becomes empty or
    # its index is reset. Plain list indexing would wrap -1 around to
    # 'prosst', so any out-of-range index falls back to "no LLM selected".
    self.llm = llm_options[i] if 0 <= i < len(llm_options) else None
656667

657668
def selection_mklsts_splits(self, i):
658669
self.mklsts_cv_method = [
@@ -675,6 +686,7 @@ def set_work_dir(self):
675686
f"Changed current working directory to: {str(getcwd())}"
676687
)
677688

689+
# Layout buttons ##########################################################
678690
def pypef_help(self):
679691
self.target_button = self.button_help
680692
self.start_process()
@@ -783,6 +795,43 @@ def pypef_dca_test(self):
783795
else:
784796
self.end_process()
785797

798+
def pypef_llm_test(self):
    """Run a zero-shot LLM test on a user-selected test set.

    Asks for a test set file in FASL format. For ProSST it additionally
    asks for a wild-type FASTA file and a PDB structure file. It then
    builds the ``hybrid --ts ... --llm ...`` command in ``self.cmd`` and
    calls ``self.start_threads()``. If a required dialog is cancelled, or
    no LLM is selected, ``self.end_process()`` is called instead and
    nothing is started.
    """
    # Point at the button that triggers this handler. The original reused
    # ``button_dca_test_dca`` (copy-paste from the DCA test handler).
    # NOTE(review): presumably start_process()/end_process() use
    # target_button to mark the running action; confirm against those methods.
    self.target_button = self.button_llm_test_zs
    self.start_process()

    test_set_file = QFileDialog.getOpenFileName(
        self.win2, "Select Test Set File in \"FASL\" format",
        filter="FASL file (*.fasl)"
    )[0]
    if not test_set_file:
        # The dialog was cancelled (empty path): abort.
        self.end_process()
        return

    if self.llm == 'prosst':
        # ProSST additionally needs the wild-type sequence and a structure.
        wt_fasta_file = QFileDialog.getOpenFileName(
            self.win2, "Select WT FASTA File",
            filter="FASTA file (*.fasta *.fa)"
        )[0]
        pdb_file = QFileDialog.getOpenFileName(
            self.win2, "Select PDB protein structure File",
            filter="PDB file (*.pdb)"
        )[0]
        if not (wt_fasta_file and pdb_file):
            self.end_process()
            return
        self.version_text.setText(
            "ProSST zero shot model inference..."
        )
        # NOTE(review): paths are interpolated unquoted; confirm how
        # run_main splits this string if paths may contain spaces.
        self.cmd = (
            f'hybrid --ts {test_set_file} --llm {self.llm} '
            f'--wt {wt_fasta_file} --pdb {pdb_file}'
        )
        self.start_threads()
    elif self.llm == 'esm':
        self.cmd = f'hybrid --ts {test_set_file} --llm {self.llm}'
        self.start_threads()
    else:
        # No LLM selected ('None' entry): tell the user and abort.
        self.logTextBox.widget.appendPlainText("Provide a LLM option for modeling.")
        self.end_process()
834+
786835
def pypef_dca_predict(self):
787836
self.target_button = self.button_dca_predict_dca
788837
self.start_process()
@@ -936,7 +985,7 @@ def pypef_dca_llm_hybrid_train(self):
936985
self.start_threads()
937986
else:
938987
self.end_process()
939-
else:
988+
elif self.llm == 'esm':
940989
if training_file and params_pkl_file:
941990
self.version_text.setText(
942991
"Hybrid (DCA+LLM-supervised) model training..."
@@ -948,6 +997,9 @@ def pypef_dca_llm_hybrid_train(self):
948997
self.start_threads()
949998
else:
950999
self.end_process()
1000+
else:
1001+
self.logTextBox.widget.appendPlainText("Provide a LLM option for modeling.")
1002+
self.end_process()
9511003

9521004
def pypef_dca_llm_hybrid_train_test(self):
9531005
self.target_button = self.button_hybrid_train_test_dca_llm
@@ -988,7 +1040,7 @@ def pypef_dca_llm_hybrid_train_test(self):
9881040
self.start_threads()
9891041
else:
9901042
self.end_process()
991-
else:
1043+
elif self.llm == 'esm':
9921044
if training_file and test_file and params_pkl_file:
9931045
self.version_text.setText(
9941046
"Hybrid (DCA+LLM-supervised) model training..."
@@ -1000,6 +1052,9 @@ def pypef_dca_llm_hybrid_train_test(self):
10001052
self.start_threads()
10011053
else:
10021054
self.end_process()
1055+
else:
1056+
self.logTextBox.widget.appendPlainText("Provide a LLM option for modeling.")
1057+
self.end_process()
10031058

10041059
def pypef_dca_llm_hybrid_test(self):
10051060
self.target_button = self.button_hybrid_test_dca_llm
@@ -1039,7 +1094,7 @@ def pypef_dca_llm_hybrid_test(self):
10391094
self.start_threads()
10401095
else:
10411096
self.end_process()
1042-
else:
1097+
elif self.llm == 'esm':
10431098
if test_file and params_pkl_file and model_file:
10441099
self.version_text.setText(
10451100
"Hybrid (DCA+LLM-supervised) model testing..."
@@ -1050,6 +1105,9 @@ def pypef_dca_llm_hybrid_test(self):
10501105
self.start_threads()
10511106
else:
10521107
self.end_process()
1108+
else:
1109+
self.logTextBox.widget.appendPlainText("Provide a LLM option for modeling.")
1110+
self.end_process()
10531111

10541112
def pypef_dca_llm_hybrid_predict(self):
10551113
self.target_button = self.button_hybrid_predict_dca_llm
@@ -1090,7 +1148,7 @@ def pypef_dca_llm_hybrid_predict(self):
10901148
self.start_threads()
10911149
else:
10921150
self.end_process()
1093-
else:
1151+
elif self.llm == 'esm':
10941152
if prediction_file and params_pkl_file and model_file:
10951153
self.version_text.setText(
10961154
"Hybrid (DCA+LLM-supervised) model training..."
@@ -1102,6 +1160,9 @@ def pypef_dca_llm_hybrid_predict(self):
11021160
self.start_threads()
11031161
else:
11041162
self.end_process()
1163+
else:
1164+
self.logTextBox.widget.appendPlainText("Provide a LLM option for modeling.")
1165+
self.end_process()
11051166

11061167
def pypef_dca_supervised_train(self):
11071168
self.target_button = self.button_supervised_train_dca

0 commit comments

Comments
 (0)