Skip to content

Commit d1d7b68

Browse files
committed
Dev: add mklsts [--random] [--modulo] [--cont] [--plot] flags
1 parent e4dc48d commit d1d7b68

File tree

3 files changed

+85
-29
lines changed

3 files changed

+85
-29
lines changed

pypef/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
112112
Usage:
113113
pypef mklsts --wt WT_FASTA --input CSV_FILE
114+
[--random] [--modulo] [--cont] [--plot]
114115
[--drop THRESHOLD] [--sep CSV_COLUMN_SEPARATOR] [--mutation_sep MUTATION_SEPARATOR]
115116
[--numrnd NUMBER] [--ls_proportion LS_PROPORTION]
116117
pypef mkps --wt WT_FASTA [--input CSV_FILE]
@@ -303,6 +304,7 @@
303304
schema = Schema({
304305
Optional('--all'): bool,
305306
Optional('--conc'): bool,
307+
Optional('--cont'): bool,
306308
Optional('--csvaa'): bool,
307309
Optional('--ddiverse'): bool,
308310
Optional('--drecomb'): bool,
@@ -319,6 +321,7 @@
319321
Optional('--ls'): Or(None, str),
320322
Optional('--ls_proportion'): Or(None, Use(float)),
321323
Optional('--model'): Or(None, str),
324+
Optional('--modulo'): bool,
322325
Optional('--msa'): Or(None, str),
323326
Optional('--mutation_sep'): Or(None, str),
324327
Optional('--negative'): bool,
@@ -331,7 +334,9 @@
331334
Optional('--params'): Or(None, str),
332335
Optional('--pdb'): Or(None, str),
333336
Optional('--pmult'): bool,
337+
Optional('--plot'): bool,
334338
Optional('--ps'): Or(None, str),
339+
Optional('--random'): bool,
335340
Optional('--qdiverse'): bool,
336341
Optional('--qarecomb'): bool,
337342
Optional('--qirecomb'): bool,

pypef/utils/split.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,25 @@ def __init__(
1717
self,
1818
df_or_csv_file: str | PathLike | pd.DataFrame,
1919
n_cv: int | None = None,
20-
mutation_column: str | None = None,
21-
separator: str | None = None
20+
mutation_column: str | None = None,
21+
mutation_separator: str | None = None,
22+
csv_separator: str | None = None
2223
):
23-
if mutation_column is None:
24-
mutation_column = 'mutant'
2524
self.mutation_column = mutation_column
26-
if separator is None:
27-
separator = ','
28-
self.separator = separator
25+
if csv_separator is None:
26+
csv_separator = ','
27+
if mutation_separator is None:
28+
mutation_separator = '/'
29+
self.mutation_separator = mutation_separator
30+
self.csv_separator = csv_separator
2931
if n_cv is None:
3032
n_cv = 5
3133
self.n_cv = n_cv
3234
if type(df_or_csv_file) == pd.DataFrame:
3335
self.df = df_or_csv_file
3436
else:
35-
self.df = pd.read_csv(self.csv_file, sep=self.separator)
37+
self.df = pd.read_csv(self.df_or_csv_file, sep=self.csv_separator)
38+
print(f'Dataframe size: {self.df.shape[0]}')
3639
self.random_splits_train_indices_combined, self.random_splits_test_indices_combined = None, None
3740
self.modulo_splits_train_indices_combined, self.modulo_splits_test_indices_combined = None, None
3841
self.cont_splits_train_indices_combined, self.cont_splits_test_indices_combined = None, None
@@ -43,9 +46,23 @@ def __init__(
4346

4447
def order_by_pos(self):
4548
if self.mutation_column is None:
46-
self.mutation_column = 'mutant'
47-
variants = self.df[self.mutation_column].to_list()
48-
self.df['variant_pos'] = [int(v[1:-1]) for v in variants]
49+
variants = self.df.iloc[:, 0].to_list()
50+
else:
51+
variants = self.df[self.mutation_column].to_list()
52+
single_mut_idxs = []
53+
for i, variant in enumerate(variants):
54+
if not self.mutation_separator in variant:
55+
single_mut_idxs.append(i)
56+
if single_mut_idxs:
57+
self.df = self.df.loc[single_mut_idxs, :]
58+
if len(single_mut_idxs) != self.df.size:
59+
print(f'Removed multimutated variants from dataframe... '
60+
f'new dataframe size: {self.df.shape[0]}')
61+
if self.mutation_column is None:
62+
variants = self.df.iloc[:, 0].to_list()
63+
else:
64+
variants = self.df[self.mutation_column].to_list()
65+
self.df.loc[:, 'variant_pos'] = [int(v[1:-1]) for v in variants]
4966
self.df['substitutions'] = [v[-1] for v in variants]
5067
self.df.sort_values(['variant_pos', 'substitutions'], ascending=[True, True], inplace=True)
5168
self.min_pos, self.max_pos = self.df['variant_pos'].to_numpy()[0], self.df['variant_pos'].to_numpy()[-1]
@@ -144,6 +161,31 @@ def get_all_split_indices(self):
144161
[self.modulo_splits_train_indices_combined, self.modulo_splits_test_indices_combined],
145162
[self.cont_splits_train_indices_combined, self.cont_splits_test_indices_combined]
146163
]
164+
165+
def _get_df_split_data(self, combined_train_indices, combined_test_indices):
166+
train_split_data, test_split_data = [], []
167+
for train_split, test_split in zip(combined_train_indices, combined_test_indices):
168+
train_split_data.append(self.df.iloc[train_split, :])
169+
test_split_data.append(self.df.iloc[test_split, :])
170+
return train_split_data, test_split_data
171+
172+
def get_random_df_split_data(self):
173+
return self._get_df_split_data(
174+
self.random_splits_train_indices_combined,
175+
self.random_splits_test_indices_combined
176+
)
177+
178+
def get_modulo_df_split_data(self):
179+
return self._get_df_split_data(
180+
self.modulo_splits_train_indices_combined,
181+
self.modulo_splits_test_indices_combined
182+
)
183+
184+
def get_continuous_df_split_data(self):
185+
return self._get_df_split_data(
186+
self.cont_splits_train_indices_combined,
187+
self.cont_splits_test_indices_combined
188+
)
147189

148190
def plot_distributions(self):
149191
fig, axs = plt.subplots(

pypef/utils/utils_run.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,19 @@
3434

3535

3636
def run_pypef_utils(arguments):
37-
if arguments['mklsts'] or ['mklsts_rnd'] or ['mklsts_mod'] or ['mklsts_cont'] or ['mklsts_plot']:
37+
if arguments['mklsts']:
3838
wt_sequence = get_wt_sequence(arguments['--wt'])
3939
t_drop = float(arguments['--drop'])
4040
ls_proportion = arguments['--ls_proportion']
4141
logger.info(f'Length of provided sequence: {len(wt_sequence)} amino acids.')
42-
logger.info(f'Training set proportion (--ls_proportion): {ls_proportion}.')
42+
if True in [
43+
arguments['--random'], arguments['--modulo'],
44+
arguments['--cont'], arguments['--plot']
45+
]:
46+
logger.info(f'Ignoring set proportion (--ls_proportion).')
47+
else:
48+
logger.info(f'Training set proportion (--ls_proportion): {ls_proportion}.')
49+
4350
df = drop_rows(arguments['--input'], amino_acids, t_drop,
4451
arguments['--sep'], arguments['--mutation_sep'])
4552
no_rnd = arguments['--numrnd']
@@ -49,18 +56,33 @@ def run_pypef_utils(arguments):
4956
if len(single_variants) == 0:
5057
logger.info('Found no single substitution variants for possible recombination!')
5158

52-
if arguments['mklsts']:
59+
if True in [
60+
arguments['--random'], arguments['--modulo'],
61+
arguments['--cont'], arguments['--plot']
62+
]:
63+
ds = DatasetSplitter(df)
64+
if arguments['--random']:
65+
train_data, test_data = ds.get_random_df_split_data()
66+
print(train_data)
67+
elif arguments['--modulo']:
68+
train_data, test_data = ds.get_modulo_df_split_data()
69+
print(train_data)
70+
elif arguments['--cont']:
71+
train_data, test_data = ds.get_continuous_df_split_data()
72+
print(train_data)
73+
elif arguments['--plot']:
74+
ds.print_shapes()
75+
ds.plot_distributions()
76+
else:
5377
sub_ls, val_ls, sub_ts, val_ts = make_sub_ls_ts(
5478
single_variants, single_values,
5579
higher_variants, higher_values,
5680
ls_proportion
5781
)
5882
logger.info('Tip: You can edit your LS and TS datasets just by '
5983
'cutting/pasting between the LS and TS fasta datasets.')
60-
6184
make_fasta_ls_ts('LS.fasl', wt_sequence, sub_ls, val_ls)
6285
make_fasta_ls_ts('TS.fasl', wt_sequence, sub_ts, val_ts)
63-
6486
try:
6587
no_rnd = int(no_rnd)
6688
except ValueError:
@@ -77,19 +99,6 @@ def run_pypef_utils(arguments):
7799
make_fasta_ls_ts('LS_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ls, val_ls)
78100
make_fasta_ls_ts('TS_random_' + str(random_set_counter) + '.fasl', wt_sequence, sub_ts, val_ts)
79101
random_set_counter += 1
80-
else:
81-
ds = DatasetSplitter(df)
82-
if arguments['mklsts_rnd']:
83-
pass # TODO
84-
85-
elif arguments['mklsts_mod']:
86-
pass # TODO
87-
88-
elif arguments['mklsts_cont']:
89-
pass # TODO
90-
91-
elif arguments['mklsts_plot']:
92-
pass # TODO
93102

94103
elif arguments['mkps']:
95104
wt_sequence = get_wt_sequence(arguments['--wt'])

0 commit comments

Comments
 (0)