@@ -82,6 +82,13 @@ def main():
8282 type = int ,
8383 default = 10
8484 )
85+ p_process_datasets .add_argument (
86+ '-b' , '--balance_by' , dest = 'BALANCE_BY' ,
87+ choices = ['auc' , 'fit_auc' ],
88+ default = None ,
89+ help = "Defines if and using which drug response metric the splits "
90+ "should be balanced by."
91+ )
8592 p_process_datasets .add_argument (
8693 '-r' , '--random_seeds' , dest = 'RANDOM_SEEDS' ,
8794 type = _random_seed_list ,
@@ -166,7 +173,7 @@ def process_datasets(args):
166173 logger .debug ("creating list of datasets that contain experiment info ..." )
167174 for data_set in data_sets_names_list :
168175 # sarcpdo has different drug response values
169- if data_set == 'sarcpdo' :
176+ if data_set == 'sarcpdo' and data_sets [ data_set ]. experiments is not None :
170177 experiment = data_sets [data_set ].format (
171178 data_type = 'experiments' ,
172179 shape = 'wide' ,
@@ -763,13 +770,21 @@ def split_data_sets(
763770 args : dict ,
764771 data_sets : dict ,
765772 data_sets_names : list ,
766- response_data : pd .DataFrame
773+ response_data : pd .DataFrame ,
767774 ):
768775
769776 splits_folder = args .WORKDIR .joinpath ('data_out' , 'splits' )
770777 split_type = args .SPLIT_TYPE
771778 ratio = (8 ,1 ,1 )
772- stratify_by = None
779+ stratify_by = args .BALANCE_BY
780+ if stratify_by is not None :
781+ balance = True
782+ quantiles = False
783+ num_classes = 4
784+ else :
785+ balance = False
786+ quantiles = True
787+ num_classes = 4
773788 if args .RANDOM_SEEDS is not None :
774789 random_seeds = args .RANDOM_SEEDS
775790 else :
@@ -818,6 +833,9 @@ def split_data_sets(
818833 split_type = split_type ,
819834 ratio = ratio ,
820835 stratify_by = stratify_by ,
836+ balance = balance ,
837+ quantiles = quantiles ,
838+ num_classes = num_classes ,
821839 random_state = random_seeds [i ]
822840 )
823841 train_keys = (
0 commit comments