Skip to content

Commit a12b680

Browse files
committed
transitioned train_test_validate to use the two way split internally
1 parent d7a5fa9 commit a12b680

File tree

1 file changed

+81
-261
lines changed

1 file changed

+81
-261
lines changed

coderdata/dataset/dataset.py

Lines changed: 81 additions & 261 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,31 @@ def split_train_other(
357357
)
358358

359359
return split
360+
361+
362+
def split_train_test_validate(
363+
self,
364+
split_type: Literal[
365+
'mixed-set', 'drug-blind', 'cancer-blind'
366+
]='mixed-set',
367+
ratio: tuple[int, int, int]=(8,1,1),
368+
stratify_by: Optional[str]=None,
369+
random_state: Optional[Union[int,RandomState]]=None,
370+
**kwargs: dict,
371+
) -> Split:
372+
373+
split = split_train_test_validate(
374+
data=self,
375+
split_type=split_type,
376+
ratio=ratio,
377+
stratify_by=stratify_by,
378+
random_state=random_state,
379+
**kwargs
380+
)
381+
382+
return split
383+
384+
360385
def train_test_validate(
361386
self,
362387
split_type: Literal[
@@ -368,7 +393,7 @@ def train_test_validate(
368393
**kwargs: dict,
369394
) -> Split:
370395

371-
split = train_test_validate(
396+
split = split_train_test_validate(
372397
data=self,
373398
split_type=split_type,
374399
ratio=ratio,
@@ -715,6 +740,53 @@ def split_train_other(
715740
train.experiments = train.experiments[train.experiments['dose_response_metric'] != 'split_class']
716741
other.experiments = other.experiments[other.experiments['dose_response_metric'] != 'split_class']
717742
return TwoWaySplit(train=train, other=other)
743+
744+
745+
def split_train_test_validate(
746+
data: Dataset,
747+
split_type: Literal[
748+
'mixed-set', 'drug-blind', 'cancer-blind'
749+
]='mixed-set',
750+
ratio: tuple[int, int, int]=(8,1,1),
751+
stratify_by: Optional[str]=None,
752+
random_state: Optional[Union[int,RandomState]]=None,
753+
**kwargs: dict,
754+
) -> Split:
755+
756+
# Type checking split_type
757+
if split_type not in [
758+
'mixed-set', 'drug-blind', 'cancer-blind'
759+
]:
760+
raise ValueError(
761+
f"{split_type} not an excepted input for 'split_type'"
762+
)
763+
764+
train, other = _split_two_way(
765+
data=data,
766+
split_type=split_type,
767+
ratio=[ratio[0], ratio[1] + ratio[2]],
768+
stratify_by=stratify_by,
769+
random_state=random_state,
770+
kwargs=kwargs,
771+
)
772+
773+
test, val = _split_two_way(
774+
data=other,
775+
split_type=split_type,
776+
ratio=[ratio[1], ratio[2]],
777+
stratify_by=stratify_by,
778+
random_state=random_state,
779+
kwargs=kwargs,
780+
)
781+
782+
if stratify_by is not None:
783+
train.experiments = train.experiments[train.experiments['dose_response_metric'] != 'split_class']
784+
test.experiments = test.experiments[test.experiments['dose_response_metric'] != 'split_class']
785+
val.experiments = val.experiments[val.experiments['dose_response_metric'] != 'split_class']
786+
787+
return Split(train=train, test=test, validate=val)
788+
789+
718790
def train_test_validate(
719791
data: Dataset,
720792
split_type: Literal[
@@ -794,266 +866,14 @@ def train_test_validate(
794866
795867
"""
796868

797-
# reading in the potential keyword arguments that will be passed to
798-
# _create_classes().
799-
thresh = kwargs.get('thresh', None)
800-
num_classes = kwargs.get('num_classes', 2)
801-
quantiles = kwargs.get('quantiles', True)
802-
803-
# Type checking split_type
804-
if split_type not in [
805-
'mixed-set', 'drug-blind', 'cancer-blind'
806-
]:
807-
raise ValueError(
808-
f"{split_type} not an excepted input for 'split_type'"
809-
)
810-
811-
# A wide (pivoted) table is more easy to work with in this instance.
812-
# The pivot is done using all columns but the 'dose_respones_value'
813-
# and 'dose_respones_metric' as index. df.pivot will generate a
814-
# MultiIndex which complicates things further down the line. To that
815-
# end 'reset_index()' is used to remove the MultiIndex
816-
df_full = data.experiments.copy()
817-
df_full = df_full.pivot(
818-
index = [
819-
'source',
820-
'improve_sample_id',
821-
'improve_drug_id',
822-
'study',
823-
'time',
824-
'time_unit'
825-
],
826-
columns = 'dose_response_metric',
827-
values = 'dose_response_value'
828-
).reset_index()
829-
830-
# Defining the split sizes.
831-
train_size = float(ratio[0]) / sum(ratio)
832-
test_val_size = float(ratio[1] + ratio[2]) / sum(ratio)
833-
test_size = float(ratio[1]) / (ratio[1] + ratio[2])
834-
validate_size = 1 - test_size
835-
836-
# ShuffleSplit is a method/class implemented by scikit-learn that
837-
# enables creating splits where the data is shuffled and then
838-
# randomly distributed into train and test sets according to the
839-
# defined ratio.
840-
#
841-
# n_splits defines how often a train/test split is generated.
842-
# Individual splits (if more than 1 is generated) are not guaranteed
843-
# to be disjoint i.e. test sets from individual splits can overlap.
844-
#
845-
# ShuffleSplit will be used for non stratified mixed-set splitting
846-
# since there is no requirement for disjoint groups (i.e. drug /
847-
# sample ids).
848-
shs_1 = ShuffleSplit(
849-
n_splits=1,
850-
train_size=train_size,
851-
test_size=test_val_size,
852-
random_state=random_state
853-
)
854-
shs_2 = ShuffleSplit(
855-
n_splits=1,
856-
train_size=test_size,
857-
test_size=validate_size,
858-
random_state=random_state
859-
)
860-
861-
# GroupShuffleSplit is an extension to ShuffleSplit that also
862-
# factors in a group that is used to generate disjoint train and
863-
# test sets, e.g. in this particular case the drug or sample id to
864-
# generate drug-blind or sample-blind train and test sets.
865-
#
866-
# GroupShuffleSplit will be used for non stratified drug-/sample-
867-
# blind splitting, i.e. there is a requirement that instances from
868-
# one group (e.g. a specific drug) are only present in the training
869-
# set but not in the test set.
870-
gss_1 = GroupShuffleSplit(
871-
n_splits=1,
872-
train_size=train_size,
873-
test_size=test_val_size,
874-
random_state=random_state
875-
)
876-
gss_2 = GroupShuffleSplit(
877-
n_splits=1,
878-
train_size=test_size,
879-
test_size=validate_size,
880-
random_state=random_state
881-
)
882-
883-
# StratifiedShuffleSplit is similar to ShuffleSplit with the added
884-
# functionality to also stratify the splits according to defined
885-
# class labels.
886-
#
887-
# StratifiedShuffleSplit will be used for stratified mixed-set
888-
# train/test/validate sets.
889-
890-
sss_1 = StratifiedShuffleSplit(
891-
n_splits=1,
892-
train_size=train_size,
893-
test_size=test_val_size,
894-
random_state=random_state
895-
)
896-
sss_2 = StratifiedShuffleSplit(
897-
n_splits=1,
898-
train_size=test_size,
899-
test_size=validate_size,
900-
random_state=random_state
901-
)
902-
903-
# StratifiedGroupKFold generates K folds that take the group into
904-
# account when generating folds, i.e. a group will only be present
905-
# in one fold. It further tries to stratify the folds based on the
906-
# defined classes.
907-
#
908-
# StratifiedGroupKFold will be used for stratified drug-/sample-
909-
# blind splitting.
910-
#
911-
# The way the K folds are utilized is to combine i, j, & k folds
912-
# (according to the defined ratio) into training, testing and
913-
# validation sets.
914-
sgk = StratifiedGroupKFold(
915-
n_splits=sum(ratio),
916-
shuffle=True,
917-
random_state=random_state
918-
)
919-
920-
# The "actual" splitting logic using the defined Splitters as above
921-
# follows here starting with the non-stratified splitting:
922-
if stratify_by is None:
923-
if split_type == 'mixed-set':
924-
# Using ShuffleSplit to generate randomized train and
925-
# 'other' set, since there is no need for grouping.
926-
idx1, idx2 = next(
927-
shs_1.split(df_full)
928-
)
929-
elif split_type == 'drug-blind':
930-
# Using GroupShuffleSplit to created disjoint train and
931-
# 'other' sets by drug id
932-
idx1, idx2 = next(
933-
gss_1.split(df_full, groups=df_full.improve_drug_id)
934-
)
935-
elif split_type == 'cancer-blind':
936-
# same as above we just group over the sample id
937-
idx1, idx2 = next(
938-
gss_1.split(df_full, groups=df_full.improve_sample_id)
939-
)
940-
else:
941-
raise Exception(f"Should be unreachable")
942-
943-
# generate new DFs containing the subset of items extracted for
944-
# train and other
945-
df_train = df_full.iloc[idx1]
946-
df_other = df_full.iloc[idx2]
947-
948-
# follows same logic as previous splitting with the difference
949-
# that only "other" is sampled and split
950-
if split_type == 'mixed-set':
951-
idx1, idx2 = next(
952-
shs_2.split(df_other, groups=None)
953-
)
954-
elif split_type == 'drug-blind':
955-
idx1, idx2 = next(
956-
gss_2.split(df_other, groups=df_other.improve_drug_id)
957-
)
958-
elif split_type == 'cancer-blind':
959-
idx1, idx2 = next(
960-
gss_2.split(df_other, groups=df_other.improve_sample_id)
961-
)
962-
else:
963-
raise Exception(f"Should be unreachable")
964-
965-
# extract itmes for test and validate from other based on the
966-
# sampled indices
967-
df_test = df_other.iloc[idx1]
968-
df_val = df_other.iloc[idx2]
969-
970-
# The following block contains the stratified splitting logic
971-
else:
972-
# First the classes that are needed for the stratification are
973-
# generated. `num_classes`, `thresh` and `quantiles` were
974-
# previously defined as possible keyword arguments.
975-
df_full = _create_classes(
976-
data=df_full,
977-
metric=stratify_by,
978-
num_classes=num_classes,
979-
thresh=thresh,
980-
quantiles=quantiles,
981-
)
982-
if split_type == 'mixed-set':
983-
# Using StratifiedShuffleSplit to generate randomized train
984-
# and 'other' set, since there is no need for grouping.
985-
idx_train, idx_other = next(
986-
sss_1.split(X=df_full, y=df_full['split_class'])
987-
)
988-
df_train = df_full.iloc[idx_train]
989-
df_train = df_train.drop(labels=['split_class'], axis=1)
990-
df_other = df_full.iloc[idx_other]
991-
# Splitting 'other' further into test and validate
992-
idx_test, idx_val = next(
993-
sss_2.split(X=df_other, y=df_other['split_class'])
994-
)
995-
df_test = df_other.iloc[idx_test]
996-
df_test = df_test.drop(labels=['split_class'], axis=1)
997-
df_val = df_other.iloc[idx_val]
998-
df_val = df_val.drop(labels=['split_class'], axis=1)
999-
1000-
# using StratifiedGroupKSplit for the stratified drug-/sample-
1001-
# blind splits.
1002-
elif split_type == 'drug-blind' or split_type == 'cancer-blind':
1003-
if split_type == 'drug-blind':
1004-
splitter = enumerate(
1005-
sgk.split(
1006-
X=df_full,
1007-
y=df_full['split_class'],
1008-
groups=df_full.improve_drug_id
1009-
)
1010-
)
1011-
elif split_type == 'cancer-blind':
1012-
splitter = enumerate(
1013-
sgk.split(
1014-
X=df_full,
1015-
y=df_full['split_class'],
1016-
groups=df_full.improve_sample_id
1017-
)
1018-
)
1019-
1020-
# StratifiedGroupKSplit is setup to generate K splits where
1021-
# K=sum(ratios) (e.g. 10 if ratio=8:1:1). To obtain three
1022-
# sets (train/test/validate) the individual splits need to
1023-
# be combined (e.g. k=[1:8] -> train, k=9 -> test, k=10 ->
1024-
# validate). The code block below does that by combining
1025-
# all indices (row numbers) that go into individual sets and
1026-
# then extracting and adding those rows into the individual
1027-
# sets.
1028-
idx_train = []
1029-
idx_test = []
1030-
idx_val = []
1031-
for i, (idx1, idx2) in splitter:
1032-
if i < ratio[0]:
1033-
idx_train.extend(idx2)
1034-
elif i >= ratio[0] and i < (ratio[0] + ratio[1]):
1035-
idx_test.extend(idx2)
1036-
elif (
1037-
i >= (ratio[0] + ratio[1])
1038-
and i < (ratio[0] + ratio[1] + ratio[2])
1039-
):
1040-
idx_val.extend(idx2)
1041-
df_full.drop(labels=['split_class'], axis=1, inplace=True)
1042-
df_train = df_full.iloc[idx_train]
1043-
df_test = df_full.iloc[idx_test]
1044-
df_val = df_full.iloc[idx_val]
1045-
else:
1046-
raise Exception(f"Should be unreachable")
1047-
1048-
1049-
# generating filtered CoderData objects that contain only the
1050-
# respective data for each split
1051-
data_train = _filter(data, df_train)
1052-
data_test = _filter(data, df_test)
1053-
data_val = _filter(data, df_val)
1054-
1055-
return Split(data_train, data_test, data_val)
1056-
869+
return split_train_test_validate(
870+
data=data,
871+
split_type=split_type,
872+
ratio=ratio,
873+
stratify_by=stratify_by,
874+
random_state=random_state,
875+
kwargs=kwargs,
876+
)
1057877

1058878

1059879
def _load_file(file_path: Path) -> pd.DataFrame:

0 commit comments

Comments
 (0)