Skip to content

Commit b2c5c3c

Browse files
Uses isinstance instead of issubclass in resampling strategy (#1160)
* Move to isinstance * Fixed unit test * Fix unit test * Simplify test * no class name * elif is better
1 parent 43a3de5 commit b2c5c3c

File tree

5 files changed

+242
-293
lines changed

5 files changed

+242
-293
lines changed

autosklearn/automl.py

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash
4747
from autosklearn.evaluation.abstract_evaluator import _fit_and_suppress_warnings
48-
from autosklearn.evaluation.train_evaluator import _fit_with_budget
48+
from autosklearn.evaluation.train_evaluator import TrainEvaluator, _fit_with_budget
4949
from autosklearn.metrics import calculate_metric
5050
from autosklearn.util.backend import Backend
5151
from autosklearn.util.stopwatch import StopWatch
@@ -164,32 +164,6 @@ def __init__(self,
164164
self._scoring_functions = scoring_functions if scoring_functions is not None else []
165165
self._resampling_strategy_arguments = resampling_strategy_arguments \
166166
if resampling_strategy_arguments is not None else {}
167-
if self._resampling_strategy not in ['holdout',
168-
'holdout-iterative-fit',
169-
'cv',
170-
'cv-iterative-fit',
171-
'partial-cv',
172-
'partial-cv-iterative-fit',
173-
] \
174-
and not issubclass(self._resampling_strategy, BaseCrossValidator)\
175-
and not issubclass(self._resampling_strategy, _RepeatedSplits)\
176-
and not issubclass(self._resampling_strategy, BaseShuffleSplit):
177-
raise ValueError('Illegal resampling strategy: %s' %
178-
self._resampling_strategy)
179-
180-
if self._resampling_strategy in ['partial-cv',
181-
'partial-cv-iterative-fit',
182-
] \
183-
and self._ensemble_size != 0:
184-
raise ValueError("Resampling strategy %s cannot be used "
185-
"together with ensembles." % self._resampling_strategy)
186-
if self._resampling_strategy in ['partial-cv',
187-
'cv',
188-
'cv-iterative-fit',
189-
'partial-cv-iterative-fit',
190-
]\
191-
and 'folds' not in self._resampling_strategy_arguments:
192-
self._resampling_strategy_arguments['folds'] = 5
193167
self._n_jobs = n_jobs
194168
self._dask_client = dask_client
195169

@@ -506,6 +480,15 @@ def fit(
506480
task=self._task,
507481
)
508482

483+
# Check the re-sampling strategy
484+
try:
485+
self._check_resampling_strategy(
486+
X=X, y=y, task=task,
487+
)
488+
except Exception as e:
489+
self._fit_cleanup()
490+
raise e
491+
509492
# Reset learnt stuff
510493
self.models_ = None
511494
self.cv_models_ = None
@@ -832,6 +815,63 @@ def _fit_cleanup(self):
832815
self._clean_logger()
833816
return
834817

818+
def _check_resampling_strategy(
819+
self,
820+
X: SUPPORTED_FEAT_TYPES,
821+
y: SUPPORTED_TARGET_TYPES,
822+
task: int,
823+
) -> None:
824+
"""
825+
This method centralizes the checks for resampling strategies
826+
827+
Parameters
828+
----------
829+
X: (SUPPORTED_FEAT_TYPES)
830+
Input features for the given task
831+
y: (SUPPORTED_TARGET_TYPES)
832+
Input targets for the given task
833+
task: (task)
834+
Integer describing a supported task type, like BINARY_CLASSIFICATION
835+
"""
836+
is_split_object = isinstance(
837+
self._resampling_strategy,
838+
(BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit)
839+
)
840+
841+
if self._resampling_strategy not in [
842+
'holdout',
843+
'holdout-iterative-fit',
844+
'cv',
845+
'cv-iterative-fit',
846+
'partial-cv',
847+
'partial-cv-iterative-fit',
848+
] and not is_split_object:
849+
raise ValueError('Illegal resampling strategy: %s' % self._resampling_strategy)
850+
851+
elif is_split_object:
852+
TrainEvaluator.check_splitter_resampling_strategy(
853+
X=X, y=y, task=task,
854+
groups=self._resampling_strategy_arguments.get('groups', None),
855+
resampling_strategy=self._resampling_strategy,
856+
)
857+
858+
elif self._resampling_strategy in [
859+
'partial-cv',
860+
'partial-cv-iterative-fit',
861+
] and self._ensemble_size != 0:
862+
raise ValueError("Resampling strategy %s cannot be used "
863+
"together with ensembles." % self._resampling_strategy)
864+
865+
elif self._resampling_strategy in [
866+
'partial-cv',
867+
'cv',
868+
'cv-iterative-fit',
869+
'partial-cv-iterative-fit',
870+
] and 'folds' not in self._resampling_strategy_arguments:
871+
self._resampling_strategy_arguments['folds'] = 5
872+
873+
return
874+
835875
@staticmethod
836876
def subsample_if_too_large(
837877
X: SUPPORTED_FEAT_TYPES,

autosklearn/evaluation/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,9 @@ def __init__(
139139
eval_function = autosklearn.evaluation.train_evaluator.eval_iterative_holdout
140140
elif resampling_strategy == 'cv-iterative-fit':
141141
eval_function = autosklearn.evaluation.train_evaluator.eval_iterative_cv
142-
elif resampling_strategy == 'cv' or (
143-
isinstance(resampling_strategy, type) and (
144-
issubclass(resampling_strategy, BaseCrossValidator) or
145-
issubclass(resampling_strategy, _RepeatedSplits) or
146-
issubclass(resampling_strategy, BaseShuffleSplit)
147-
)
148-
):
142+
elif resampling_strategy == 'cv' or isinstance(resampling_strategy, (
143+
BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit)
144+
):
149145
eval_function = autosklearn.evaluation.train_evaluator.eval_cv
150146
elif resampling_strategy == 'partial-cv':
151147
eval_function = autosklearn.evaluation.train_evaluator.eval_partial_cv

autosklearn/evaluation/train_evaluator.py

Lines changed: 54 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -42,38 +42,6 @@
4242
__all__ = ['TrainEvaluator', 'eval_holdout', 'eval_iterative_holdout',
4343
'eval_cv', 'eval_partial_cv', 'eval_partial_cv_iterative']
4444

45-
baseCrossValidator_defaults: Dict[str, Dict[str, Optional[Union[int, float, str]]]] = {
46-
'GroupKFold': {'n_splits': 3},
47-
'KFold': {'n_splits': 3,
48-
'shuffle': False,
49-
'random_state': None},
50-
'LeaveOneGroupOut': {},
51-
'LeavePGroupsOut': {'n_groups': 2},
52-
'LeaveOneOut': {},
53-
'LeavePOut': {'p': 2},
54-
'PredefinedSplit': {},
55-
'RepeatedKFold': {'n_splits': 5,
56-
'n_repeats': 10,
57-
'random_state': None},
58-
'RepeatedStratifiedKFold': {'n_splits': 5,
59-
'n_repeats': 10,
60-
'random_state': None},
61-
'StratifiedKFold': {'n_splits': 3,
62-
'shuffle': False,
63-
'random_state': None},
64-
'TimeSeriesSplit': {'n_splits': 3,
65-
'max_train_size': None},
66-
'GroupShuffleSplit': {'n_splits': 5,
67-
'test_size': None,
68-
'random_state': None},
69-
'StratifiedShuffleSplit': {'n_splits': 10,
70-
'test_size': None,
71-
'random_state': None},
72-
'ShuffleSplit': {'n_splits': 10,
73-
'test_size': None,
74-
'random_state': None}
75-
}
76-
7745

7846
def _get_y_array(y: SUPPORTED_TARGET_TYPES, task_type: int) -> SUPPORTED_TARGET_TYPES:
7947
if task_type in CLASSIFICATION_TASKS and task_type != \
@@ -1027,69 +995,30 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep
1027995
if self.resampling_strategy_args is None:
1028996
self.resampling_strategy_args = {}
1029997

1030-
if self.resampling_strategy is not None and not isinstance(self.resampling_strategy, str):
1031-
1032-
if issubclass(self.resampling_strategy, BaseCrossValidator) or \
1033-
issubclass(self.resampling_strategy, _RepeatedSplits) or \
1034-
issubclass(self.resampling_strategy, BaseShuffleSplit):
1035-
1036-
class_name = self.resampling_strategy.__name__
1037-
if class_name not in baseCrossValidator_defaults:
1038-
raise ValueError('Unknown CrossValidator.')
1039-
ref_arg_dict = baseCrossValidator_defaults[class_name]
1040-
1041-
y = D.data['Y_train']
1042-
if (D.info['task'] in CLASSIFICATION_TASKS and
1043-
D.info['task'] != MULTILABEL_CLASSIFICATION) or \
1044-
(D.info['task'] in REGRESSION_TASKS and
1045-
D.info['task'] != MULTIOUTPUT_REGRESSION):
1046-
1047-
y = y.ravel()
1048-
if class_name == 'PredefinedSplit':
1049-
if 'test_fold' not in self.resampling_strategy_args:
1050-
raise ValueError('Must provide parameter test_fold'
1051-
' for class PredefinedSplit.')
1052-
if class_name == 'LeaveOneGroupOut' or \
1053-
class_name == 'LeavePGroupsOut' or\
1054-
class_name == 'GroupKFold' or\
1055-
class_name == 'GroupShuffleSplit':
1056-
if 'groups' not in self.resampling_strategy_args:
1057-
raise ValueError('Must provide parameter groups '
1058-
'for chosen CrossValidator.')
1059-
try:
1060-
if np.shape(self.resampling_strategy_args['groups'])[0] != y.shape[0]:
1061-
raise ValueError('Groups must be array-like '
1062-
'with shape (n_samples,).')
1063-
except Exception:
1064-
raise ValueError('Groups must be array-like '
1065-
'with shape (n_samples,).')
1066-
else:
1067-
if 'groups' in self.resampling_strategy_args:
1068-
if np.shape(self.resampling_strategy_args['groups'])[0] != y.shape[0]:
1069-
raise ValueError('Groups must be array-like'
1070-
' with shape (n_samples,).')
1071-
1072-
# Put args in self.resampling_strategy_args
1073-
for key in ref_arg_dict:
1074-
if key == 'n_splits':
1075-
if 'folds' not in self.resampling_strategy_args:
1076-
self.resampling_strategy_args['folds'] = ref_arg_dict['n_splits']
1077-
else:
1078-
if key not in self.resampling_strategy_args:
1079-
self.resampling_strategy_args[key] = ref_arg_dict[key]
1080-
1081-
# Instantiate object with args
1082-
init_dict = copy.deepcopy(self.resampling_strategy_args)
1083-
init_dict.pop('groups', None)
1084-
if 'folds' in init_dict:
1085-
init_dict['n_splits'] = init_dict.pop('folds', None)
1086-
assert self.resampling_strategy is not None
1087-
cv = copy.deepcopy(self.resampling_strategy)(**init_dict)
1088-
1089-
if 'groups' not in self.resampling_strategy_args:
1090-
self.resampling_strategy_args['groups'] = None
998+
if (
999+
self.resampling_strategy is not None
1000+
and not isinstance(self.resampling_strategy, str)
1001+
):
1002+
if 'groups' not in self.resampling_strategy_args:
1003+
self.resampling_strategy_args['groups'] = None
1004+
1005+
if isinstance(self.resampling_strategy, (BaseCrossValidator,
1006+
_RepeatedSplits,
1007+
BaseShuffleSplit)):
1008+
self.check_splitter_resampling_strategy(
1009+
X=D.data['X_train'], y=D.data['Y_train'],
1010+
groups=self.resampling_strategy_args.get('groups'),
1011+
task=D.info['task'],
1012+
resampling_strategy=self.resampling_strategy,
1013+
)
1014+
return self.resampling_strategy
10911015

1092-
return cv
1016+
# If it got to this point, we are dealing with a non-supported
1017+
# re-sampling strategy
1018+
raise ValueError("Unsupported resampling strategy {}/{} provided".format(
1019+
self.resampling_strategy,
1020+
type(self.resampling_strategy),
1021+
))
10931022

10941023
y = D.data['Y_train']
10951024
shuffle = self.resampling_strategy_args.get('shuffle', True)
@@ -1161,6 +1090,37 @@ def get_splitter(self, D: AbstractDataManager) -> Union[BaseCrossValidator, _Rep
11611090
raise ValueError(self.resampling_strategy)
11621091
return cv
11631092

1093+
@classmethod
1094+
def check_splitter_resampling_strategy(
1095+
cls,
1096+
X: PIPELINE_DATA_DTYPE,
1097+
y: np.ndarray,
1098+
task: int,
1099+
groups: Any,
1100+
resampling_strategy: Union[BaseCrossValidator, _RepeatedSplits,
1101+
BaseShuffleSplit],
1102+
) -> None:
1103+
if (
1104+
task in CLASSIFICATION_TASKS
1105+
and task != MULTILABEL_CLASSIFICATION
1106+
or (
1107+
task in REGRESSION_TASKS
1108+
and task != MULTIOUTPUT_REGRESSION
1109+
)
1110+
):
1111+
y = y.ravel()
1112+
1113+
try:
1114+
resampling_strategy.get_n_splits(X=X, y=y, groups=groups)
1115+
next(resampling_strategy.split(X=X, y=y, groups=groups))
1116+
except Exception as e:
1117+
raise ValueError("Unsupported resampling strategy "
1118+
"{}/{} cause exception: {}".format(
1119+
resampling_strategy,
1120+
groups,
1121+
str(e),
1122+
))
1123+
11641124

11651125
# create closure for evaluating an algorithm
11661126
def eval_holdout(

examples/40_advanced/example_resampling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,15 @@
9898
# data by the first feature. In practice, one would use a splitting according
9999
# to the use case at hand.
100100

101-
resampling_strategy = sklearn.model_selection.PredefinedSplit
102-
resampling_strategy_arguments = {'test_fold': np.where(X_train[:, 0] < np.mean(X_train[:, 0]))[0]}
101+
resampling_strategy = sklearn.model_selection.PredefinedSplit(
102+
test_fold=np.where(X_train[:, 0] < np.mean(X_train[:, 0]))[0])
103103

104104
automl = autosklearn.classification.AutoSklearnClassifier(
105105
time_left_for_this_task=120,
106106
per_run_time_limit=30,
107107
tmp_folder='/tmp/autosklearn_resampling_example_tmp',
108108
disable_evaluator_output=False,
109109
resampling_strategy=resampling_strategy,
110-
resampling_strategy_arguments=resampling_strategy_arguments,
111110
)
112111
automl.fit(X_train, y_train, dataset_name='breast_cancer')
113112

0 commit comments

Comments
 (0)