Skip to content

Commit e94ee3c

Browse files
committed
add mklsts cv options (flags) to GUI
1 parent d1d7b68 commit e94ee3c

File tree

5 files changed

+76
-22
lines changed

5 files changed

+76
-22
lines changed

pypef/gui/PyPEFGUIQtWindow.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(self):
149149
self.sig_start = Signal() # needed only due to PyCharm debugger bug (!)
150150
self.llm = 'esm'
151151
self.regression_model = 'PLS'
152+
self.mklsts_cv_method = ''
152153
self.c = 0
153154
self.ls_proportion = 0.8
154155
self.setMinimumSize(QSize(1400, 400))
@@ -167,6 +168,7 @@ def __init__(self):
167168
self.llm_text = QLabel("LLM")
168169
self.regression_model_text = QLabel("Regression model")
169170
self.utils_text = QLabel("Utilities")
171+
self.mklsts_cv_options_text = QLabel("Cross-validation split options")
170172
self.dca_text = QLabel("DCA (unsupervised)")
171173
self.hybrid_text = QLabel("Hybrid (supervised DCA)")
172174
self.hybrid_dca_llm_text = QLabel("Hybrid (supervised DCA+LLM)")
@@ -215,7 +217,7 @@ def __init__(self):
215217
self.slider.move(10, 105)
216218
self.slider.valueChanged.connect(self.selection_ls_proportion)
217219

218-
# Boxes ########################################################################
220+
# ComboBoxes ########################################################################
219221
self.box_regression_model = QComboBox()
220222
self.regression_models = [
221223
'PLS', 'PLS_LOOCV', 'Ridge', 'Lasso', 'ElasticNet', 'SVR', 'RF', 'MLP'
@@ -232,6 +234,15 @@ def __init__(self):
232234
self.box_llm.addItems(['ESM1v', 'ProSST'])
233235
self.box_llm.currentIndexChanged.connect(self.selection_llm_model)
234236
self.box_llm.setStyleSheet("color:white;background-color:rgb(54, 69, 79);")
237+
238+
self.box_mklsts_cv = QComboBox()
239+
self.box_mklsts_cv.addItems([
240+
'None', 'Random split', 'Modulo split',
241+
'Continuous split', 'Plot distribution'
242+
])
243+
self.box_mklsts_cv.currentIndexChanged.connect(self.selection_mklsts_splits)
244+
self.box_mklsts_cv.setStyleSheet("color:white;background-color:rgb(54, 69, 79);")
245+
235246

236247
# Buttons ######################################################################
237248
# Utilities
@@ -527,6 +538,8 @@ def __init__(self):
527538
layout.addWidget(self.button_mklsts, 5, 0, 1, 1)
528539
layout.addWidget(self.button_mkps, 6, 0, 1, 1)
529540

541+
layout.addWidget(self.mklsts_cv_options_text, 1, 1, 1, 1)
542+
layout.addWidget(self.box_mklsts_cv, 2, 1, 1, 1)
530543
layout.addWidget(self.dca_text, 3, 1, 1, 1)
531544
layout.addWidget(self.button_dca_inference_gremlin, 4, 1, 1, 1)
532545
layout.addWidget(self.button_dca_inference_gremlin_msa_info, 5, 1, 1, 1)
@@ -641,6 +654,11 @@ def selection_regression_model(self, i):
641654
def selection_llm_model(self, i):
642655
self.llm = ['esm', 'prosst'][i]
643656

657+
def selection_mklsts_splits(self, i):
658+
self.mklsts_cv_method = [
659+
'', '--random', '--modulo', '--cont', '--plot'
660+
][i]
661+
644662
def selection_ls_proportion(self, value):
645663
self.ls_proportion = value / 100
646664
self.slider_text.setText(
@@ -680,7 +698,7 @@ def pypef_mklsts(self):
680698
self.version_text.setText("Running MKLSTS...")
681699
self.cmd = (
682700
f'mklsts --wt {wt_fasta_file} --input {csv_variant_file} '
683-
f'--ls_proportion {self.ls_proportion}'
701+
f'--ls_proportion {self.ls_proportion} {self.mklsts_cv_method}'
684702
)
685703
self.start_threads()
686704
else:

pypef/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@
173173
--all Finally training on all data [default: False].
174174
--conc Concatenating mutational level variants for predicting variants
175175
from next higher level [default: False].
176+
--cont Continuous splits in five-fold cross-validation fashion that
177+
split the data based on the positions of mutations.
176178
--csvaa Directed evolution csv amino acid substitutions,
177179
requires flag "--usecsv" [default: False].
178180
--ddiverse Create/predict double natural diverse variants [default: False].
@@ -197,6 +199,8 @@
197199
--llm LLM LLM model to use for hybrid modeling next to DCA (options are 'ESM1v' and 'ProSST').
198200
-m --model MODEL Model (pickle file) for plotting of validation or for
199201
performing predictions.
202+
--modulo Modulo-like splits in five-fold cross-validation fashion that
203+
split the data based on the positions of mutations.
200204
--msa MSA_FILE Multiple sequence alignment (MSA) in FASTA or A2M format for
201205
inferring DCA parameters.
202206
--mutation_sep MUTATION_SEP Mutation separator [default: /].
@@ -214,6 +218,9 @@
214218
and couplings [default: 100].
215219
--params PARAM_FILE Input PLMC couplings parameter file.
216220
--pdb PDB_FILE Input protein structure file in PDB format used for ProSST LLM modeling.
221+
--plot Plot different five-fold dataset split distributions performed when using
222+
the flags --random, --modulo, --cont with the mklsts command.
223+
--random Random splits in five-fold cross-validation fashion.
217224
-u --pmult Predict for all prediction files in folder for recombinants
218225
or for diverse variants [default: False].
219226
-p --ps PREDICTION_SET Prediction set for performing predictions using a trained Model.

pypef/utils/learning_test_sets.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def get_variants(
116116
df,
117117
amino_acids,
118118
wild_type_sequence,
119-
mutation_sep: str = '/'
119+
mutation_sep: str = '/',
120+
verbose=True
120121
):
121122
"""
122123
Gets variants and divides and counts the variant data for single substituted
@@ -202,12 +203,13 @@ def get_variants(
202203
single_variants.append([full_variant])
203204
if i not in index_lower:
204205
index_lower.append(i)
205-
logger.info(
206-
'Single (for mklsts if provided plus WT): {}, Double: {}, Triple: {}, Quadruple: {}, Quintuple: {}, '
207-
'Sextuple: {}, Septuple: {}, Octuple: {}, Nonuple: {}, Decuple: {}, Higher (>Decuple): {}'.format(
208-
single, double, triple, quadruple, quintuple, sextuple, septuple, octuple, nonuple, decuple, higher
206+
if verbose:
207+
logger.info(
208+
'Single (for mklsts if provided plus WT): {}, Double: {}, Triple: {}, Quadruple: {}, Quintuple: {}, '
209+
'Sextuple: {}, Septuple: {}, Octuple: {}, Nonuple: {}, Decuple: {}, Higher (>Decuple): {}'.format(
210+
single, double, triple, quadruple, quintuple, sextuple, septuple, octuple, nonuple, decuple, higher
211+
)
209212
)
210-
)
211213
for vals in y[index_higher]:
212214
higher_values.append(vals)
213215
for vals in y[index_lower]:

pypef/utils/split.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ def __init__(
3333
self.n_cv = n_cv
3434
if type(df_or_csv_file) == pd.DataFrame:
3535
self.df = df_or_csv_file
36+
self.fig_path = path.abspath('CV_split_pos_aa_distr.png')
3637
else:
3738
self.df = pd.read_csv(self.df_or_csv_file, sep=self.csv_separator)
38-
print(f'Dataframe size: {self.df.shape[0]}')
39+
self.fig_path = path.abspath(path.splitext(path.basename(
40+
self.df_or_csv_file))[0] + '_pos_aa_distr.png')
41+
logger.info(f'Dataframe size: {self.df.shape[0]}')
3942
self.random_splits_train_indices_combined, self.random_splits_test_indices_combined = None, None
4043
self.modulo_splits_train_indices_combined, self.modulo_splits_test_indices_combined = None, None
4144
self.cont_splits_train_indices_combined, self.cont_splits_test_indices_combined = None, None
@@ -56,13 +59,14 @@ def order_by_pos(self):
5659
if single_mut_idxs:
5760
self.df = self.df.loc[single_mut_idxs, :]
5861
if len(single_mut_idxs) != self.df.size:
59-
print(f'Removed multimutated variants from dataframe... '
62+
logger.info(f'Removed multimutated variants from dataframe... '
6063
f'new dataframe size: {self.df.shape[0]}')
6164
if self.mutation_column is None:
6265
variants = self.df.iloc[:, 0].to_list()
6366
else:
6467
variants = self.df[self.mutation_column].to_list()
65-
self.df.loc[:, 'variant_pos'] = [int(v[1:-1]) for v in variants]
68+
self.df.reset_index(drop=True, inplace=True)
69+
self.df['variant_pos'] = [int(v[1:-1]) for v in variants]
6670
self.df['substitutions'] = [v[-1] for v in variants]
6771
self.df.sort_values(['variant_pos', 'substitutions'], ascending=[True, True], inplace=True)
6872
self.min_pos, self.max_pos = self.df['variant_pos'].to_numpy()[0], self.df['variant_pos'].to_numpy()[-1]
@@ -165,8 +169,12 @@ def get_all_split_indices(self):
165169
def _get_df_split_data(self, combined_train_indices, combined_test_indices):
166170
train_split_data, test_split_data = [], []
167171
for train_split, test_split in zip(combined_train_indices, combined_test_indices):
168-
train_split_data.append(self.df.iloc[train_split, :])
169-
test_split_data.append(self.df.iloc[test_split, :])
172+
train_split_data.append(
173+
self.df.iloc[train_split, :].reset_index(drop=True)
174+
)
175+
test_split_data.append(
176+
self.df.iloc[test_split, :].reset_index(drop=True)
177+
)
170178
return train_split_data, test_split_data
171179

172180
def get_random_df_split_data(self):
@@ -192,6 +200,7 @@ def plot_distributions(self):
192200
nrows=4, ncols=self.n_cv,
193201
constrained_layout=True
194202
)
203+
logger.info("Plotting distributions...")
195204
fig.set_figwidth(30)
196205
fig.set_figheight(10)
197206

@@ -234,7 +243,7 @@ def plot_distributions(self):
234243
axs[i_category + 1, i_split].set_ylim(0, 20)
235244
axs[i_category + 1, i_split].set_xlim(self.min_pos - 4, self.max_pos + 4)
236245
axs[0, self.n_cv // 2].set_xticks(xticks)
237-
fig_path = path.abspath(path.splitext(path.basename(self.csv_file))[0] + '_pos_aa_distr.png')
238-
plt.savefig(fig_path, dpi=300)
239-
logger.info(f"Saved figure as {fig_path}.")
246+
247+
plt.savefig(self.fig_path, dpi=300)
248+
logger.info(f"Saved figure as {self.fig_path}.")
240249
plt.close(fig)

pypef/utils/utils_run.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,37 @@ def run_pypef_utils(arguments):
6060
arguments['--random'], arguments['--modulo'],
6161
arguments['--cont'], arguments['--plot']
6262
]:
63-
ds = DatasetSplitter(df)
63+
ds = DatasetSplitter(df, mutation_separator=arguments['--mutation_sep'])
6464
if arguments['--random']:
6565
train_data, test_data = ds.get_random_df_split_data()
66-
print(train_data)
66+
cv_technique = 'random'
6767
elif arguments['--modulo']:
6868
train_data, test_data = ds.get_modulo_df_split_data()
69-
print(train_data)
69+
cv_technique = 'modulo'
7070
elif arguments['--cont']:
7171
train_data, test_data = ds.get_continuous_df_split_data()
72-
print(train_data)
72+
cv_technique = 'continuous'
7373
elif arguments['--plot']:
7474
ds.print_shapes()
7575
ds.plot_distributions()
76+
if not arguments['--plot']:
77+
for i_cv, (train_set, test_set) in enumerate(zip(train_data, test_data)):
78+
single_variants_train, single_values_train, _, _ = get_variants(
79+
train_set, amino_acids, wt_sequence,
80+
arguments['--mutation_sep'], verbose=False
81+
)
82+
single_variants_test, single_values_test, _, _ = get_variants(
83+
test_set, amino_acids, wt_sequence,
84+
arguments['--mutation_sep'], verbose=False
85+
)
86+
make_fasta_ls_ts(
87+
f'LS_{cv_technique}_{i_cv + 1 }.fasl', wt_sequence,
88+
single_variants_train, single_values_train
89+
)
90+
make_fasta_ls_ts(
91+
f'TS_{cv_technique}_{i_cv + 1 }.fasl', wt_sequence,
92+
single_variants_test, single_values_test
93+
)
7694
else:
7795
sub_ls, val_ls, sub_ts, val_ts = make_sub_ls_ts(
7896
single_variants, single_values,
@@ -96,8 +114,8 @@ def run_pypef_utils(arguments):
96114
higher_variants, higher_values,
97115
ls_proportion
98116
)
99-
make_fasta_ls_ts('LS_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ls, val_ls)
100-
make_fasta_ls_ts('TS_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ts, val_ts)
117+
make_fasta_ls_ts('LS_default_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ls, val_ls)
118+
make_fasta_ls_ts('TS_default_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ts, val_ts)
101119
random_set_counter += 1
102120

103121
elif arguments['mkps']:

0 commit comments

Comments
 (0)