Skip to content

Commit c0e3c03

Browse files
committed
added basic balancing logic - currently 4 evenly spaced bins
1 parent fb709c0 commit c0e3c03

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

scripts/prepare_data_for_improve.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,13 @@ def main():
8282
type=int,
8383
default=10
8484
)
85+
p_process_datasets.add_argument(
86+
'-b', '--balance_by', dest='BALANCE_BY',
87+
choices=['auc', 'fit_auc'],
88+
default=None,
89+
help="Defines if and using which drug response metric the splits "
90+
"should be balanced by."
91+
)
8592
p_process_datasets.add_argument(
8693
'-r', '--random_seeds', dest='RANDOM_SEEDS',
8794
type=_random_seed_list,
@@ -166,7 +173,7 @@ def process_datasets(args):
166173
logger.debug("creating list of datasets that contain experiment info ...")
167174
for data_set in data_sets_names_list:
168175
# sarcpdo has different drug response values
169-
if data_set == 'sarcpdo':
176+
if data_set == 'sarcpdo' and data_sets[data_set].experiments is not None:
170177
experiment = data_sets[data_set].format(
171178
data_type='experiments',
172179
shape='wide',
@@ -763,13 +770,21 @@ def split_data_sets(
763770
args: dict,
764771
data_sets: dict,
765772
data_sets_names: list,
766-
response_data: pd.DataFrame
773+
response_data: pd.DataFrame,
767774
):
768775

769776
splits_folder = args.WORKDIR.joinpath('data_out', 'splits')
770777
split_type = args.SPLIT_TYPE
771778
ratio = (8,1,1)
772-
stratify_by = None
779+
stratify_by = args.BALANCE_BY
780+
if stratify_by is not None:
781+
balance = True
782+
quantiles = False
783+
num_classes = 4
784+
else:
785+
balance = False
786+
quantiles = True
787+
num_classes = 4
773788
if args.RANDOM_SEEDS is not None:
774789
random_seeds = args.RANDOM_SEEDS
775790
else:
@@ -818,6 +833,9 @@ def split_data_sets(
818833
split_type=split_type,
819834
ratio=ratio,
820835
stratify_by=stratify_by,
836+
balance=balance,
837+
quantiles=quantiles,
838+
num_classes=num_classes,
821839
random_state=random_seeds[i]
822840
)
823841
train_keys = (

0 commit comments

Comments
 (0)