@@ -357,6 +357,31 @@ def split_train_other(
357357 )
358358
359359 return split
360+
361+
362+ def split_train_test_validate (
363+ self ,
364+ split_type : Literal [
365+ 'mixed-set' , 'drug-blind' , 'cancer-blind'
366+ ]= 'mixed-set' ,
367+ ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
368+ stratify_by : Optional [str ]= None ,
369+ random_state : Optional [Union [int ,RandomState ]]= None ,
370+ ** kwargs : dict ,
371+ ) -> Split :
372+
373+ split = split_train_test_validate (
374+ data = self ,
375+ split_type = split_type ,
376+ ratio = ratio ,
377+ stratify_by = stratify_by ,
378+ random_state = random_state ,
379+ ** kwargs
380+ )
381+
382+ return split
383+
384+
360385 def train_test_validate (
361386 self ,
362387 split_type : Literal [
@@ -368,7 +393,7 @@ def train_test_validate(
368393 ** kwargs : dict ,
369394 ) -> Split :
370395
371- split = train_test_validate (
396+ split = split_train_test_validate (
372397 data = self ,
373398 split_type = split_type ,
374399 ratio = ratio ,
@@ -715,6 +740,53 @@ def split_train_other(
715740 train .experiments = train .experiments [train .experiments ['dose_response_metric' ] != 'split_class' ]
716741 other .experiments = other .experiments [other .experiments ['dose_response_metric' ] != 'split_class' ]
717742 return TwoWaySplit (train = train , other = other )
743+
744+
745+ def split_train_test_validate (
746+ data : Dataset ,
747+ split_type : Literal [
748+ 'mixed-set' , 'drug-blind' , 'cancer-blind'
749+ ]= 'mixed-set' ,
750+ ratio : tuple [int , int , int ]= (8 ,1 ,1 ),
751+ stratify_by : Optional [str ]= None ,
752+ random_state : Optional [Union [int ,RandomState ]]= None ,
753+ ** kwargs : dict ,
754+ ) -> Split :
755+
756+ # Type checking split_type
757+ if split_type not in [
758+ 'mixed-set' , 'drug-blind' , 'cancer-blind'
759+ ]:
760+ raise ValueError (
761+ f"{ split_type } not an excepted input for 'split_type'"
762+ )
763+
764+ train , other = _split_two_way (
765+ data = data ,
766+ split_type = split_type ,
767+ ratio = [ratio [0 ], ratio [1 ] + ratio [2 ]],
768+ stratify_by = stratify_by ,
769+ random_state = random_state ,
770+ kwargs = kwargs ,
771+ )
772+
773+ test , val = _split_two_way (
774+ data = other ,
775+ split_type = split_type ,
776+ ratio = [ratio [1 ], ratio [2 ]],
777+ stratify_by = stratify_by ,
778+ random_state = random_state ,
779+ kwargs = kwargs ,
780+ )
781+
782+ if stratify_by is not None :
783+ train .experiments = train .experiments [train .experiments ['dose_response_metric' ] != 'split_class' ]
784+ test .experiments = test .experiments [test .experiments ['dose_response_metric' ] != 'split_class' ]
785+ val .experiments = val .experiments [val .experiments ['dose_response_metric' ] != 'split_class' ]
786+
787+ return Split (train = train , test = test , validate = val )
788+
789+
718790def train_test_validate (
719791 data : Dataset ,
720792 split_type : Literal [
@@ -794,266 +866,14 @@ def train_test_validate(
794866
795867 """
796868
797- # reading in the potential keyword arguments that will be passed to
798- # _create_classes().
799- thresh = kwargs .get ('thresh' , None )
800- num_classes = kwargs .get ('num_classes' , 2 )
801- quantiles = kwargs .get ('quantiles' , True )
802-
803- # Type checking split_type
804- if split_type not in [
805- 'mixed-set' , 'drug-blind' , 'cancer-blind'
806- ]:
807- raise ValueError (
808- f"{ split_type } not an excepted input for 'split_type'"
809- )
810-
811- # A wide (pivoted) table is more easy to work with in this instance.
812- # The pivot is done using all columns but the 'dose_respones_value'
813- # and 'dose_respones_metric' as index. df.pivot will generate a
814- # MultiIndex which complicates things further down the line. To that
815- # end 'reset_index()' is used to remove the MultiIndex
816- df_full = data .experiments .copy ()
817- df_full = df_full .pivot (
818- index = [
819- 'source' ,
820- 'improve_sample_id' ,
821- 'improve_drug_id' ,
822- 'study' ,
823- 'time' ,
824- 'time_unit'
825- ],
826- columns = 'dose_response_metric' ,
827- values = 'dose_response_value'
828- ).reset_index ()
829-
830- # Defining the split sizes.
831- train_size = float (ratio [0 ]) / sum (ratio )
832- test_val_size = float (ratio [1 ] + ratio [2 ]) / sum (ratio )
833- test_size = float (ratio [1 ]) / (ratio [1 ] + ratio [2 ])
834- validate_size = 1 - test_size
835-
836- # ShuffleSplit is a method/class implemented by scikit-learn that
837- # enables creating splits where the data is shuffled and then
838- # randomly distributed into train and test sets according to the
839- # defined ratio.
840- #
841- # n_splits defines how often a train/test split is generated.
842- # Individual splits (if more than 1 is generated) are not guaranteed
843- # to be disjoint i.e. test sets from individual splits can overlap.
844- #
845- # ShuffleSplit will be used for non stratified mixed-set splitting
846- # since there is no requirement for disjoint groups (i.e. drug /
847- # sample ids).
848- shs_1 = ShuffleSplit (
849- n_splits = 1 ,
850- train_size = train_size ,
851- test_size = test_val_size ,
852- random_state = random_state
853- )
854- shs_2 = ShuffleSplit (
855- n_splits = 1 ,
856- train_size = test_size ,
857- test_size = validate_size ,
858- random_state = random_state
859- )
860-
861- # GroupShuffleSplit is an extension to ShuffleSplit that also
862- # factors in a group that is used to generate disjoint train and
863- # test sets, e.g. in this particular case the drug or sample id to
864- # generate drug-blind or sample-blind train and test sets.
865- #
866- # GroupShuffleSplit will be used for non stratified drug-/sample-
867- # blind splitting, i.e. there is a requirement that instances from
868- # one group (e.g. a specific drug) are only present in the training
869- # set but not in the test set.
870- gss_1 = GroupShuffleSplit (
871- n_splits = 1 ,
872- train_size = train_size ,
873- test_size = test_val_size ,
874- random_state = random_state
875- )
876- gss_2 = GroupShuffleSplit (
877- n_splits = 1 ,
878- train_size = test_size ,
879- test_size = validate_size ,
880- random_state = random_state
881- )
882-
883- # StratifiedShuffleSplit is similar to ShuffleSplit with the added
884- # functionality to also stratify the splits according to defined
885- # class labels.
886- #
887- # StratifiedShuffleSplit will be used for stratified mixed-set
888- # train/test/validate sets.
889-
890- sss_1 = StratifiedShuffleSplit (
891- n_splits = 1 ,
892- train_size = train_size ,
893- test_size = test_val_size ,
894- random_state = random_state
895- )
896- sss_2 = StratifiedShuffleSplit (
897- n_splits = 1 ,
898- train_size = test_size ,
899- test_size = validate_size ,
900- random_state = random_state
901- )
902-
903- # StratifiedGroupKFold generates K folds that take the group into
904- # account when generating folds, i.e. a group will only be present
905- # in one fold. It further tries to stratify the folds based on the
906- # defined classes.
907- #
908- # StratifiedGroupKFold will be used for stratified drug-/sample-
909- # blind splitting.
910- #
911- # The way the K folds are utilized is to combine i, j, & k folds
912- # (according to the defined ratio) into training, testing and
913- # validation sets.
914- sgk = StratifiedGroupKFold (
915- n_splits = sum (ratio ),
916- shuffle = True ,
917- random_state = random_state
918- )
919-
920- # The "actual" splitting logic using the defined Splitters as above
921- # follows here starting with the non-stratified splitting:
922- if stratify_by is None :
923- if split_type == 'mixed-set' :
924- # Using ShuffleSplit to generate randomized train and
925- # 'other' set, since there is no need for grouping.
926- idx1 , idx2 = next (
927- shs_1 .split (df_full )
928- )
929- elif split_type == 'drug-blind' :
930- # Using GroupShuffleSplit to created disjoint train and
931- # 'other' sets by drug id
932- idx1 , idx2 = next (
933- gss_1 .split (df_full , groups = df_full .improve_drug_id )
934- )
935- elif split_type == 'cancer-blind' :
936- # same as above we just group over the sample id
937- idx1 , idx2 = next (
938- gss_1 .split (df_full , groups = df_full .improve_sample_id )
939- )
940- else :
941- raise Exception (f"Should be unreachable" )
942-
943- # generate new DFs containing the subset of items extracted for
944- # train and other
945- df_train = df_full .iloc [idx1 ]
946- df_other = df_full .iloc [idx2 ]
947-
948- # follows same logic as previous splitting with the difference
949- # that only "other" is sampled and split
950- if split_type == 'mixed-set' :
951- idx1 , idx2 = next (
952- shs_2 .split (df_other , groups = None )
953- )
954- elif split_type == 'drug-blind' :
955- idx1 , idx2 = next (
956- gss_2 .split (df_other , groups = df_other .improve_drug_id )
957- )
958- elif split_type == 'cancer-blind' :
959- idx1 , idx2 = next (
960- gss_2 .split (df_other , groups = df_other .improve_sample_id )
961- )
962- else :
963- raise Exception (f"Should be unreachable" )
964-
965- # extract itmes for test and validate from other based on the
966- # sampled indices
967- df_test = df_other .iloc [idx1 ]
968- df_val = df_other .iloc [idx2 ]
969-
970- # The following block contains the stratified splitting logic
971- else :
972- # First the classes that are needed for the stratification are
973- # generated. `num_classes`, `thresh` and `quantiles` were
974- # previously defined as possible keyword arguments.
975- df_full = _create_classes (
976- data = df_full ,
977- metric = stratify_by ,
978- num_classes = num_classes ,
979- thresh = thresh ,
980- quantiles = quantiles ,
981- )
982- if split_type == 'mixed-set' :
983- # Using StratifiedShuffleSplit to generate randomized train
984- # and 'other' set, since there is no need for grouping.
985- idx_train , idx_other = next (
986- sss_1 .split (X = df_full , y = df_full ['split_class' ])
987- )
988- df_train = df_full .iloc [idx_train ]
989- df_train = df_train .drop (labels = ['split_class' ], axis = 1 )
990- df_other = df_full .iloc [idx_other ]
991- # Splitting 'other' further into test and validate
992- idx_test , idx_val = next (
993- sss_2 .split (X = df_other , y = df_other ['split_class' ])
994- )
995- df_test = df_other .iloc [idx_test ]
996- df_test = df_test .drop (labels = ['split_class' ], axis = 1 )
997- df_val = df_other .iloc [idx_val ]
998- df_val = df_val .drop (labels = ['split_class' ], axis = 1 )
999-
1000- # using StratifiedGroupKSplit for the stratified drug-/sample-
1001- # blind splits.
1002- elif split_type == 'drug-blind' or split_type == 'cancer-blind' :
1003- if split_type == 'drug-blind' :
1004- splitter = enumerate (
1005- sgk .split (
1006- X = df_full ,
1007- y = df_full ['split_class' ],
1008- groups = df_full .improve_drug_id
1009- )
1010- )
1011- elif split_type == 'cancer-blind' :
1012- splitter = enumerate (
1013- sgk .split (
1014- X = df_full ,
1015- y = df_full ['split_class' ],
1016- groups = df_full .improve_sample_id
1017- )
1018- )
1019-
1020- # StratifiedGroupKSplit is setup to generate K splits where
1021- # K=sum(ratios) (e.g. 10 if ratio=8:1:1). To obtain three
1022- # sets (train/test/validate) the individual splits need to
1023- # be combined (e.g. k=[1:8] -> train, k=9 -> test, k=10 ->
1024- # validate). The code block below does that by combining
1025- # all indices (row numbers) that go into individual sets and
1026- # then extracting and adding those rows into the individual
1027- # sets.
1028- idx_train = []
1029- idx_test = []
1030- idx_val = []
1031- for i , (idx1 , idx2 ) in splitter :
1032- if i < ratio [0 ]:
1033- idx_train .extend (idx2 )
1034- elif i >= ratio [0 ] and i < (ratio [0 ] + ratio [1 ]):
1035- idx_test .extend (idx2 )
1036- elif (
1037- i >= (ratio [0 ] + ratio [1 ])
1038- and i < (ratio [0 ] + ratio [1 ] + ratio [2 ])
1039- ):
1040- idx_val .extend (idx2 )
1041- df_full .drop (labels = ['split_class' ], axis = 1 , inplace = True )
1042- df_train = df_full .iloc [idx_train ]
1043- df_test = df_full .iloc [idx_test ]
1044- df_val = df_full .iloc [idx_val ]
1045- else :
1046- raise Exception (f"Should be unreachable" )
1047-
1048-
1049- # generating filtered CoderData objects that contain only the
1050- # respective data for each split
1051- data_train = _filter (data , df_train )
1052- data_test = _filter (data , df_test )
1053- data_val = _filter (data , df_val )
1054-
1055- return Split (data_train , data_test , data_val )
1056-
869+ return split_train_test_validate (
870+ data = data ,
871+ split_type = split_type ,
872+ ratio = ratio ,
873+ stratify_by = stratify_by ,
874+ random_state = random_state ,
875+ kwargs = kwargs ,
876+ )
1057877
1058878
1059879def _load_file (file_path : Path ) -> pd .DataFrame :
0 commit comments