Skip to content

Commit 763aac0

Browse files
ahn1340mfeurer
authored andcommitted
Check target type at the beginning of the fitting process. (#506)
* Check target type at the beginning of the fitting process. * . * Fixed minor error in uniitest * . * Add unittest for target type checking. * . * . * Change datasets used in examples from digits to breast_cancer. * [Debug] try with numpy version 1.14.5 * [Debug] Check if numpy version 1.14.6 raises error. * [Debug] try different numpy version * [Debug] Try with latest numpy version * Set numpy version to 1.14.5 * Check target type at the beginning of the fitting process. * . * Fixed minor error in uniitest * . * Add unittest for target type checking. * . * . * [Debug] Check if numpy version 1.14.6 raises error. * Fix numpy version to 1.14.5 * Add comment to Mock in test_type_of_target * Fix line length in example_parallel.py * Fix minor error
1 parent 3f0ee66 commit 763aac0

File tree

2 files changed

+157
-2
lines changed

2 files changed

+157
-2
lines changed

autosklearn/estimators.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from autosklearn.automl import AutoMLClassifier, AutoMLRegressor
55
from autosklearn.util.backend import create
6+
from sklearn.utils.multiclass import type_of_target
67

78

89
class AutoSklearnEstimator(BaseEstimator):
@@ -465,6 +466,18 @@ def fit(self, X, y,
465466
self
466467
467468
"""
469+
# Before running anything else, first check that the
470+
# type of data is compatible with auto-sklearn. Legal target
471+
# types are: binary, multiclass, multilabel-indicator.
472+
target_type = type_of_target(y)
473+
if target_type in ['multiclass-multioutput',
474+
'continuous',
475+
'continuous-multioutput',
476+
'unknown',
477+
]:
478+
raise ValueError("classification with data of type %s is"
479+
" not supported" % target_type)
480+
468481
super().fit(
469482
X=X,
470483
y=y,
@@ -568,6 +581,18 @@ def fit(self, X, y,
568581
self
569582
570583
"""
584+
# Before running anything else, first check that the
585+
# type of data is compatible with auto-sklearn. Legal target
586+
# types are: continuous, binary, multiclass.
587+
target_type = type_of_target(y)
588+
if target_type in ['multiclass-multioutput',
589+
'multilabel-indicator',
590+
'continuous-multioutput',
591+
'unknown',
592+
]:
593+
raise ValueError("regression with data of type %s is not"
594+
" supported" % target_type)
595+
571596
# Fit is supposed to be idempotent!
572597
# But not if we use share_mode.
573598
super().fit(

test/test_automl/test_estimators.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,25 @@ class EstimatorTest(Base, unittest.TestCase):
5050
# self._tearDown(output)
5151

5252
def test_pSMAC_wrong_arguments(self):
53+
X = np.zeros((100, 100))
54+
y = np.zeros((100, ))
5355
self.assertRaisesRegexp(ValueError,
5456
"If shared_mode == True tmp_folder must not "
5557
"be None.",
56-
lambda shared_mode: AutoSklearnClassifier(shared_mode=shared_mode).fit(None, None),
58+
lambda shared_mode:
59+
AutoSklearnClassifier(
60+
shared_mode=shared_mode,
61+
).fit(X, y),
5762
shared_mode=True)
5863

5964
self.assertRaisesRegexp(ValueError,
6065
"If shared_mode == True output_folder must not "
6166
"be None.",
6267
lambda shared_mode, tmp_folder:
63-
AutoSklearnClassifier(shared_mode=shared_mode, tmp_folder=tmp_folder).fit(None, None),
68+
AutoSklearnClassifier(
69+
shared_mode=shared_mode,
70+
tmp_folder=tmp_folder,
71+
).fit(X, y),
6472
shared_mode=True,
6573
tmp_folder='/tmp/duitaredxtvbedb')
6674

@@ -85,6 +93,128 @@ def test_feat_type_wrong_arguments(self):
8593
cls.fit,
8694
X=X, y=y, feat_type=['Car']*100)
8795

96+
# Mock AutoSklearnEstimator.fit so the test doesn't actually run fit().
97+
@unittest.mock.patch('autosklearn.estimators.AutoSklearnEstimator.fit')
98+
def test_type_of_target(self, mock_estimator):
99+
# Test that classifier raises error for illegal target types.
100+
X = np.array([[1, 2],
101+
[2, 3],
102+
[3, 4],
103+
[4, 5],
104+
])
105+
# Possible target types
106+
y_binary = np.array([0, 0, 1, 1])
107+
y_continuous = np.array([0.1, 1.3, 2.1, 4.0])
108+
y_multiclass = np.array([0, 1, 2, 0])
109+
y_multilabel = np.array([[0, 1],
110+
[1, 1],
111+
[1, 0],
112+
[0, 0],
113+
])
114+
y_multiclass_multioutput = np.array([[0, 1],
115+
[1, 3],
116+
[2, 2],
117+
[5, 3],
118+
])
119+
y_continuous_multioutput = np.array([[0.1, 1.5],
120+
[1.2, 3.5],
121+
[2.7, 2.7],
122+
[5.5, 3.9],
123+
])
124+
125+
cls = AutoSklearnClassifier()
126+
# Illegal target types for classification: continuous,
127+
# multiclass-multioutput, continuous-multioutput.
128+
self.assertRaisesRegex(ValueError,
129+
"classification with data of type"
130+
" multiclass-multioutput is not supported",
131+
cls.fit,
132+
X=X,
133+
y=y_multiclass_multioutput,
134+
)
135+
136+
self.assertRaisesRegex(ValueError,
137+
"classification with data of type"
138+
" continuous is not supported",
139+
cls.fit,
140+
X=X,
141+
y=y_continuous,
142+
)
143+
144+
self.assertRaisesRegex(ValueError,
145+
"classification with data of type"
146+
" continuous-multioutput is not supported",
147+
cls.fit,
148+
X=X,
149+
y=y_continuous_multioutput,
150+
)
151+
152+
# Legal target types for classification: binary, multiclass,
153+
# multilabel-indicator.
154+
try:
155+
cls.fit(X, y_binary)
156+
except ValueError:
157+
self.fail("cls.fit() raised ValueError while fitting "
158+
"binary targets")
159+
160+
try:
161+
cls.fit(X, y_multiclass)
162+
except ValueError:
163+
self.fail("cls.fit() raised ValueError while fitting "
164+
"multiclass targets")
165+
166+
try:
167+
cls.fit(X, y_multilabel)
168+
except ValueError:
169+
self.fail("cls.fit() raised ValueError while fitting "
170+
"multilabel-indicator targets")
171+
172+
# Test that regressor raises error for illegal target types.
173+
reg = AutoSklearnRegressor()
174+
# Illegal target types for regression: multiclass-multioutput,
175+
# multilabel-indicator, continuous-multioutput.
176+
self.assertRaisesRegex(ValueError,
177+
"regression with data of type"
178+
" multiclass-multioutput is not supported",
179+
reg.fit,
180+
X=X,
181+
y=y_multiclass_multioutput,
182+
)
183+
184+
self.assertRaisesRegex(ValueError,
185+
"regression with data of type"
186+
" multilabel-indicator is not supported",
187+
reg.fit,
188+
X=X,
189+
y=y_multilabel,
190+
)
191+
192+
self.assertRaisesRegex(ValueError,
193+
"regression with data of type"
194+
" continuous-multioutput is not supported",
195+
reg.fit,
196+
X=X,
197+
y=y_continuous_multioutput,
198+
)
199+
# Legal target types: continuous, binary, multiclass
200+
try:
201+
reg.fit(X, y_continuous)
202+
except ValueError:
203+
self.fail("reg.fit() raised ValueError while fitting "
204+
"continuous targets")
205+
206+
try:
207+
reg.fit(X, y_binary)
208+
except ValueError:
209+
self.fail("reg.fit() raised ValueError while fitting "
210+
"binary targets")
211+
212+
try:
213+
reg.fit(X, y_multiclass)
214+
except ValueError:
215+
self.fail("reg.fit() raised ValueError while fitting "
216+
"multiclass targets")
217+
88218
def test_fit_pSMAC(self):
89219
tmp = os.path.join(self.test_dir, '..', '.tmp_estimator_fit_pSMAC')
90220
output = os.path.join(self.test_dir, '..', '.out_estimator_fit_pSMAC')

0 commit comments

Comments
 (0)