@@ -50,17 +50,25 @@ class EstimatorTest(Base, unittest.TestCase):
5050 # self._tearDown(output)
5151
5252 def test_pSMAC_wrong_arguments (self ):
53+ X = np .zeros ((100 , 100 ))
54+ y = np .zeros ((100 , ))
5355 self .assertRaisesRegexp (ValueError ,
5456 "If shared_mode == True tmp_folder must not "
5557 "be None." ,
56- lambda shared_mode : AutoSklearnClassifier (shared_mode = shared_mode ).fit (None , None ),
58+ lambda shared_mode :
59+ AutoSklearnClassifier (
60+ shared_mode = shared_mode ,
61+ ).fit (X , y ),
5762 shared_mode = True )
5863
5964 self .assertRaisesRegexp (ValueError ,
6065 "If shared_mode == True output_folder must not "
6166 "be None." ,
6267 lambda shared_mode , tmp_folder :
63- AutoSklearnClassifier (shared_mode = shared_mode , tmp_folder = tmp_folder ).fit (None , None ),
68+ AutoSklearnClassifier (
69+ shared_mode = shared_mode ,
70+ tmp_folder = tmp_folder ,
71+ ).fit (X , y ),
6472 shared_mode = True ,
6573 tmp_folder = '/tmp/duitaredxtvbedb' )
6674
@@ -85,6 +93,128 @@ def test_feat_type_wrong_arguments(self):
8593 cls .fit ,
8694 X = X , y = y , feat_type = ['Car' ]* 100 )
8795
96+ # Mock AutoSklearnEstimator.fit so the test doesn't actually run fit().
97+ @unittest .mock .patch ('autosklearn.estimators.AutoSklearnEstimator.fit' )
98+ def test_type_of_target (self , mock_estimator ):
99+ # Test that classifier raises error for illegal target types.
100+ X = np .array ([[1 , 2 ],
101+ [2 , 3 ],
102+ [3 , 4 ],
103+ [4 , 5 ],
104+ ])
105+ # Possible target types
106+ y_binary = np .array ([0 , 0 , 1 , 1 ])
107+ y_continuous = np .array ([0.1 , 1.3 , 2.1 , 4.0 ])
108+ y_multiclass = np .array ([0 , 1 , 2 , 0 ])
109+ y_multilabel = np .array ([[0 , 1 ],
110+ [1 , 1 ],
111+ [1 , 0 ],
112+ [0 , 0 ],
113+ ])
114+ y_multiclass_multioutput = np .array ([[0 , 1 ],
115+ [1 , 3 ],
116+ [2 , 2 ],
117+ [5 , 3 ],
118+ ])
119+ y_continuous_multioutput = np .array ([[0.1 , 1.5 ],
120+ [1.2 , 3.5 ],
121+ [2.7 , 2.7 ],
122+ [5.5 , 3.9 ],
123+ ])
124+
125+ cls = AutoSklearnClassifier ()
126+ # Illegal target types for classification: continuous,
127+ # multiclass-multioutput, continuous-multioutput.
128+ self .assertRaisesRegex (ValueError ,
129+ "classification with data of type"
130+ " multiclass-multioutput is not supported" ,
131+ cls .fit ,
132+ X = X ,
133+ y = y_multiclass_multioutput ,
134+ )
135+
136+ self .assertRaisesRegex (ValueError ,
137+ "classification with data of type"
138+ " continuous is not supported" ,
139+ cls .fit ,
140+ X = X ,
141+ y = y_continuous ,
142+ )
143+
144+ self .assertRaisesRegex (ValueError ,
145+ "classification with data of type"
146+ " continuous-multioutput is not supported" ,
147+ cls .fit ,
148+ X = X ,
149+ y = y_continuous_multioutput ,
150+ )
151+
152+ # Legal target types for classification: binary, multiclass,
153+ # multilabel-indicator.
154+ try :
155+ cls .fit (X , y_binary )
156+ except ValueError :
157+ self .fail ("cls.fit() raised ValueError while fitting "
158+ "binary targets" )
159+
160+ try :
161+ cls .fit (X , y_multiclass )
162+ except ValueError :
163+ self .fail ("cls.fit() raised ValueError while fitting "
164+ "multiclass targets" )
165+
166+ try :
167+ cls .fit (X , y_multilabel )
168+ except ValueError :
169+ self .fail ("cls.fit() raised ValueError while fitting "
170+ "multilabel-indicator targets" )
171+
172+ # Test that regressor raises error for illegal target types.
173+ reg = AutoSklearnRegressor ()
174+ # Illegal target types for regression: multiclass-multioutput,
175+ # multilabel-indicator, continuous-multioutput.
176+ self .assertRaisesRegex (ValueError ,
177+ "regression with data of type"
178+ " multiclass-multioutput is not supported" ,
179+ reg .fit ,
180+ X = X ,
181+ y = y_multiclass_multioutput ,
182+ )
183+
184+ self .assertRaisesRegex (ValueError ,
185+ "regression with data of type"
186+ " multilabel-indicator is not supported" ,
187+ reg .fit ,
188+ X = X ,
189+ y = y_multilabel ,
190+ )
191+
192+ self .assertRaisesRegex (ValueError ,
193+ "regression with data of type"
194+ " continuous-multioutput is not supported" ,
195+ reg .fit ,
196+ X = X ,
197+ y = y_continuous_multioutput ,
198+ )
199+ # Legal target types: continuous, binary, multiclass
200+ try :
201+ reg .fit (X , y_continuous )
202+ except ValueError :
203+ self .fail ("reg.fit() raised ValueError while fitting "
204+ "continuous targets" )
205+
206+ try :
207+ reg .fit (X , y_binary )
208+ except ValueError :
209+ self .fail ("reg.fit() raised ValueError while fitting "
210+ "binary targets" )
211+
212+ try :
213+ reg .fit (X , y_multiclass )
214+ except ValueError :
215+ self .fail ("reg.fit() raised ValueError while fitting "
216+ "multiclass targets" )
217+
88218 def test_fit_pSMAC (self ):
89219 tmp = os .path .join (self .test_dir , '..' , '.tmp_estimator_fit_pSMAC' )
90220 output = os .path .join (self .test_dir , '..' , '.out_estimator_fit_pSMAC' )
0 commit comments