|
42 | 42 | __all__ = ['TrainEvaluator', 'eval_holdout', 'eval_iterative_holdout', |
43 | 43 | 'eval_cv', 'eval_partial_cv', 'eval_partial_cv_iterative'] |
44 | 44 |
|
45 | | -baseCrossValidator_defaults: Dict[str, Dict[str, Optional[Union[int, float, str]]]] = { |
46 | | - 'GroupKFold': {'n_splits': 3}, |
47 | | - 'KFold': {'n_splits': 3, |
48 | | - 'shuffle': False, |
49 | | - 'random_state': None}, |
50 | | - 'LeaveOneGroupOut': {}, |
51 | | - 'LeavePGroupsOut': {'n_groups': 2}, |
52 | | - 'LeaveOneOut': {}, |
53 | | - 'LeavePOut': {'p': 2}, |
54 | | - 'PredefinedSplit': {}, |
55 | | - 'RepeatedKFold': {'n_splits': 5, |
56 | | - 'n_repeats': 10, |
57 | | - 'random_state': None}, |
58 | | - 'RepeatedStratifiedKFold': {'n_splits': 5, |
59 | | - 'n_repeats': 10, |
60 | | - 'random_state': None}, |
61 | | - 'StratifiedKFold': {'n_splits': 3, |
62 | | - 'shuffle': False, |
63 | | - 'random_state': None}, |
64 | | - 'TimeSeriesSplit': {'n_splits': 3, |
65 | | - 'max_train_size': None}, |
66 | | - 'GroupShuffleSplit': {'n_splits': 5, |
67 | | - 'test_size': None, |
68 | | - 'random_state': None}, |
69 | | - 'StratifiedShuffleSplit': {'n_splits': 10, |
70 | | - 'test_size': None, |
71 | | - 'random_state': None}, |
72 | | - 'ShuffleSplit': {'n_splits': 10, |
73 | | - 'test_size': None, |
74 | | - 'random_state': None} |
75 | | - } |
76 | | - |
77 | 45 |
|
78 | 46 | def _get_y_array(y: SUPPORTED_TARGET_TYPES, task_type: int) -> SUPPORTED_TARGET_TYPES: |
79 | 47 | if task_type in CLASSIFICATION_TASKS and task_type != \ |
@@ -1027,69 +995,30 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep |
1027 | 995 | if self.resampling_strategy_args is None: |
1028 | 996 | self.resampling_strategy_args = {} |
1029 | 997 |
|
1030 | | - if self.resampling_strategy is not None and not isinstance(self.resampling_strategy, str): |
1031 | | - |
1032 | | - if issubclass(self.resampling_strategy, BaseCrossValidator) or \ |
1033 | | - issubclass(self.resampling_strategy, _RepeatedSplits) or \ |
1034 | | - issubclass(self.resampling_strategy, BaseShuffleSplit): |
1035 | | - |
1036 | | - class_name = self.resampling_strategy.__name__ |
1037 | | - if class_name not in baseCrossValidator_defaults: |
1038 | | - raise ValueError('Unknown CrossValidator.') |
1039 | | - ref_arg_dict = baseCrossValidator_defaults[class_name] |
1040 | | - |
1041 | | - y = D.data['Y_train'] |
1042 | | - if (D.info['task'] in CLASSIFICATION_TASKS and |
1043 | | - D.info['task'] != MULTILABEL_CLASSIFICATION) or \ |
1044 | | - (D.info['task'] in REGRESSION_TASKS and |
1045 | | - D.info['task'] != MULTIOUTPUT_REGRESSION): |
1046 | | - |
1047 | | - y = y.ravel() |
1048 | | - if class_name == 'PredefinedSplit': |
1049 | | - if 'test_fold' not in self.resampling_strategy_args: |
1050 | | - raise ValueError('Must provide parameter test_fold' |
1051 | | - ' for class PredefinedSplit.') |
1052 | | - if class_name == 'LeaveOneGroupOut' or \ |
1053 | | - class_name == 'LeavePGroupsOut' or\ |
1054 | | - class_name == 'GroupKFold' or\ |
1055 | | - class_name == 'GroupShuffleSplit': |
1056 | | - if 'groups' not in self.resampling_strategy_args: |
1057 | | - raise ValueError('Must provide parameter groups ' |
1058 | | - 'for chosen CrossValidator.') |
1059 | | - try: |
1060 | | - if np.shape(self.resampling_strategy_args['groups'])[0] != y.shape[0]: |
1061 | | - raise ValueError('Groups must be array-like ' |
1062 | | - 'with shape (n_samples,).') |
1063 | | - except Exception: |
1064 | | - raise ValueError('Groups must be array-like ' |
1065 | | - 'with shape (n_samples,).') |
1066 | | - else: |
1067 | | - if 'groups' in self.resampling_strategy_args: |
1068 | | - if np.shape(self.resampling_strategy_args['groups'])[0] != y.shape[0]: |
1069 | | - raise ValueError('Groups must be array-like' |
1070 | | - ' with shape (n_samples,).') |
1071 | | - |
1072 | | - # Put args in self.resampling_strategy_args |
1073 | | - for key in ref_arg_dict: |
1074 | | - if key == 'n_splits': |
1075 | | - if 'folds' not in self.resampling_strategy_args: |
1076 | | - self.resampling_strategy_args['folds'] = ref_arg_dict['n_splits'] |
1077 | | - else: |
1078 | | - if key not in self.resampling_strategy_args: |
1079 | | - self.resampling_strategy_args[key] = ref_arg_dict[key] |
1080 | | - |
1081 | | - # Instantiate object with args |
1082 | | - init_dict = copy.deepcopy(self.resampling_strategy_args) |
1083 | | - init_dict.pop('groups', None) |
1084 | | - if 'folds' in init_dict: |
1085 | | - init_dict['n_splits'] = init_dict.pop('folds', None) |
1086 | | - assert self.resampling_strategy is not None |
1087 | | - cv = copy.deepcopy(self.resampling_strategy)(**init_dict) |
1088 | | - |
1089 | | - if 'groups' not in self.resampling_strategy_args: |
1090 | | - self.resampling_strategy_args['groups'] = None |
| 998 | + if ( |
| 999 | + self.resampling_strategy is not None |
| 1000 | + and not isinstance(self.resampling_strategy, str) |
| 1001 | + ): |
| 1002 | + if 'groups' not in self.resampling_strategy_args: |
| 1003 | + self.resampling_strategy_args['groups'] = None |
| 1004 | + |
| 1005 | + if isinstance(self.resampling_strategy, (BaseCrossValidator, |
| 1006 | + _RepeatedSplits, |
| 1007 | + BaseShuffleSplit)): |
| 1008 | + self.check_splitter_resampling_strategy( |
| 1009 | + X=D.data['X_train'], y=D.data['Y_train'], |
| 1010 | + groups=self.resampling_strategy_args.get('groups'), |
| 1011 | + task=D.info['task'], |
| 1012 | + resampling_strategy=self.resampling_strategy, |
| 1013 | + ) |
| 1014 | + return self.resampling_strategy |
1091 | 1015 |
|
1092 | | - return cv |
| 1016 | + # If it got to this point, we are dealing with a non-supported |
| 1017 | + # re-sampling strategy |
| 1018 | + raise ValueError("Unsupported resampling strategy {}/{} provided".format( |
| 1019 | + self.resampling_strategy, |
| 1020 | + type(self.resampling_strategy), |
| 1021 | + )) |
1093 | 1022 |
|
1094 | 1023 | y = D.data['Y_train'] |
1095 | 1024 | shuffle = self.resampling_strategy_args.get('shuffle', True) |
@@ -1161,6 +1090,37 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep |
1161 | 1090 | raise ValueError(self.resampling_strategy) |
1162 | 1091 | return cv |
1163 | 1092 |
|
@classmethod
def check_splitter_resampling_strategy(
    cls,
    X: PIPELINE_DATA_DTYPE,
    y: np.ndarray,
    task: int,
    groups: Any,
    resampling_strategy: Union[BaseCrossValidator, _RepeatedSplits,
                               BaseShuffleSplit],
) -> None:
    """Check that a splitter object can actually split the given data.

    The splitter is exercised once: ``get_n_splits`` is called and the
    first item of ``split`` is requested, both with ``X``, ``y`` and
    ``groups`` passed as keyword arguments.  Nothing is returned; the
    method only signals failure by raising.

    Parameters
    ----------
    X:
        Training features handed to the splitter.
    y:
        Training targets.  For classification tasks other than
        ``MULTILABEL_CLASSIFICATION`` and regression tasks other than
        ``MULTIOUTPUT_REGRESSION`` the targets are flattened with
        ``ravel()`` before being handed to the splitter.
    task:
        Task identifier, compared against ``CLASSIFICATION_TASKS`` /
        ``REGRESSION_TASKS``.
    groups:
        Group labels forwarded unchanged to the splitter (may be ``None``).
    resampling_strategy:
        The already-instantiated splitter object to validate.

    Raises
    ------
    ValueError
        If ``get_n_splits`` or the first ``split`` step raises any
        ``Exception``.  The original exception is attached as the cause.
    """
    # Parentheses make the and/or grouping explicit; this is the same
    # grouping Python's operator precedence gives the unparenthesized form.
    single_output_task = (
        (task in CLASSIFICATION_TASKS and task != MULTILABEL_CLASSIFICATION)
        or (task in REGRESSION_TASKS and task != MULTIOUTPUT_REGRESSION)
    )
    if single_output_task:
        y = y.ravel()

    try:
        resampling_strategy.get_n_splits(X=X, y=y, groups=groups)
        # Only the first split is requested.  A splitter that yields
        # nothing raises StopIteration here, which is an Exception
        # subclass and is therefore reported as a ValueError as well.
        next(resampling_strategy.split(X=X, y=y, groups=groups))
    except Exception as e:
        # Chain explicitly so the traceback shows the splitter's own
        # failure as the direct cause of this ValueError.
        raise ValueError("Unsupported resampling strategy "
                         "{}/{} cause exception: {}".format(
                             resampling_strategy,
                             groups,
                             str(e),
                         )) from e
| 1123 | + |
1164 | 1124 |
|
1165 | 1125 | # create closure for evaluating an algorithm |
1166 | 1126 | def eval_holdout( |
|
0 commit comments