@@ -17,22 +17,25 @@ def __init__(
1717 self ,
1818 df_or_csv_file : str | PathLike | pd .DataFrame ,
1919 n_cv : int | None = None ,
20- mutation_column : str | None = None ,
21- separator : str | None = None
20+ mutation_column : str | None = None ,
21+ mutation_separator : str | None = None ,
22+ csv_separator : str | None = None
2223 ):
23- if mutation_column is None :
24- mutation_column = 'mutant'
2524 self .mutation_column = mutation_column
26- if separator is None :
27- separator = ','
28- self .separator = separator
25+ if csv_separator is None :
26+ csv_separator = ','
27+ if mutation_separator is None :
28+ mutation_separator = '/'
29+ self .mutation_separator = mutation_separator
30+ self .csv_separator = csv_separator
2931 if n_cv is None :
3032 n_cv = 5
3133 self .n_cv = n_cv
3234 if type (df_or_csv_file ) == pd .DataFrame :
3335 self .df = df_or_csv_file
3436 else :
35- self .df = pd .read_csv (self .csv_file , sep = self .separator )
37+ self .df = pd .read_csv (self .df_or_csv_file , sep = self .csv_separator )
38+ print (f'Dataframe size: { self .df .shape [0 ]} ' )
3639 self .random_splits_train_indices_combined , self .random_splits_test_indices_combined = None , None
3740 self .modulo_splits_train_indices_combined , self .modulo_splits_test_indices_combined = None , None
3841 self .cont_splits_train_indices_combined , self .cont_splits_test_indices_combined = None , None
@@ -43,9 +46,23 @@ def __init__(
4346
4447 def order_by_pos (self ):
4548 if self .mutation_column is None :
46- self .mutation_column = 'mutant'
47- variants = self .df [self .mutation_column ].to_list ()
48- self .df ['variant_pos' ] = [int (v [1 :- 1 ]) for v in variants ]
49+ variants = self .df .iloc [:, 0 ].to_list ()
50+ else :
51+ variants = self .df [self .mutation_column ].to_list ()
52+ single_mut_idxs = []
53+ for i , variant in enumerate (variants ):
54+ if not self .mutation_separator in variant :
55+ single_mut_idxs .append (i )
56+ if single_mut_idxs :
57+ self .df = self .df .loc [single_mut_idxs , :]
58+ if len (single_mut_idxs ) != self .df .size :
59+ print (f'Removed multimutated variants from dataframe... '
60+ f'new dataframe size: { self .df .shape [0 ]} ' )
61+ if self .mutation_column is None :
62+ variants = self .df .iloc [:, 0 ].to_list ()
63+ else :
64+ variants = self .df [self .mutation_column ].to_list ()
65+ self .df .loc [:, 'variant_pos' ] = [int (v [1 :- 1 ]) for v in variants ]
4966 self .df ['substitutions' ] = [v [- 1 ] for v in variants ]
5067 self .df .sort_values (['variant_pos' , 'substitutions' ], ascending = [True , True ], inplace = True )
5168 self .min_pos , self .max_pos = self .df ['variant_pos' ].to_numpy ()[0 ], self .df ['variant_pos' ].to_numpy ()[- 1 ]
@@ -144,6 +161,31 @@ def get_all_split_indices(self):
144161 [self .modulo_splits_train_indices_combined , self .modulo_splits_test_indices_combined ],
145162 [self .cont_splits_train_indices_combined , self .cont_splits_test_indices_combined ]
146163 ]
164+
165+ def _get_df_split_data (self , combined_train_indices , combined_test_indices ):
166+ train_split_data , test_split_data = [], []
167+ for train_split , test_split in zip (combined_train_indices , combined_test_indices ):
168+ train_split_data .append (self .df .iloc [train_split , :])
169+ test_split_data .append (self .df .iloc [test_split , :])
170+ return train_split_data , test_split_data
171+
172+ def get_random_df_split_data (self ):
173+ return self ._get_df_split_data (
174+ self .random_splits_train_indices_combined ,
175+ self .random_splits_test_indices_combined
176+ )
177+
178+ def get_modulo_df_split_data (self ):
179+ return self ._get_df_split_data (
180+ self .modulo_splits_train_indices_combined ,
181+ self .modulo_splits_test_indices_combined
182+ )
183+
184+ def get_continuous_df_split_data (self ):
185+ return self ._get_df_split_data (
186+ self .cont_splits_train_indices_combined ,
187+ self .cont_splits_test_indices_combined
188+ )
147189
148190 def plot_distributions (self ):
149191 fig , axs = plt .subplots (
0 commit comments