
Commit 4d2e21c

Author: Yassine Morakakam (committed)
Commit message: save AutoML models
1 parent 5c4d66a commit 4d2e21c

File tree

1 file changed: +100 -28 lines


autosklearn/automl.py

Lines changed: 100 additions & 28 deletions
@@ -10,6 +10,8 @@
 import numpy.ma as ma
 import scipy.stats
 from sklearn.base import BaseEstimator
+from sklearn.model_selection._split import _RepeatedSplits, \
+    BaseShuffleSplit, BaseCrossValidator
 from smac.tae.execute_ta_run import StatusType
 from smac.stats.stats import Stats
 from sklearn.externals import joblib
@@ -133,11 +135,16 @@ def __init__(self,
         # After assignging and checking variables...
         #self._backend = Backend(self._output_dir, self._tmp_dir)
 
-    def fit(self, X, y,
-            task=MULTICLASS_CLASSIFICATION,
-            metric=None,
-            feat_type=None,
-            dataset_name=None):
+    def fit(
+        self, X, y,
+        task,
+        metric,
+        X_test=None,
+        y_test=None,
+        feat_type=None,
+        dataset_name=None,
+        only_return_configuration_space=False,
+    ):
         if not self._shared_mode:
             self._backend.context.delete_directories()
         else:
@@ -181,13 +188,22 @@ def fit(self, X, y,
                                  'valid feature types, you passed `%s`' % ft)
 
         self._data_memory_limit = None
-        loaded_data_manager = XYDataManager(X, y,
-                                            task=task,
-                                            feat_type=feat_type,
-                                            dataset_name=dataset_name)
+        loaded_data_manager = XYDataManager(
+            X, y,
+            X_test=X_test,
+            y_test=y_test,
+            task=task,
+            feat_type=feat_type,
+            dataset_name=dataset_name,
+        )
 
-        return self._fit(loaded_data_manager, metric)
+        return self._fit(
+            loaded_data_manager,
+            metric,
+            only_return_configuration_space,
+        )
 
+    # TODO this is very old code which can be dropped!
     def fit_automl_dataset(self, dataset, metric):
         self._stopwatch = StopWatch()
         self._backend.save_start_time(self._seed)
@@ -280,7 +296,7 @@ def _do_dummy_prediction(self, datamanager, num_run):
 
         return ta.num_run
 
-    def _fit(self, datamanager, metric):
+    def _fit(self, datamanager, metric, only_return_configuration_space=False):
         # Reset learnt stuff
         self.models_ = None
         self.ensemble_ = None
@@ -296,9 +312,13 @@ def _fit(self, datamanager, metric):
                 raise ValueError("List member '%s' for argument "
                                  "'disable_evaluator_output' must be one "
                                  "of " + str(allowed_elements))
-        if self._resampling_strategy not in ['holdout', 'holdout-iterative-fit',
-                                             'cv', 'partial-cv',
-                                             'partial-cv-iterative-fit']:
+        if self._resampling_strategy not in [
+                'holdout', 'holdout-iterative-fit',
+                'cv', 'partial-cv',
+                'partial-cv-iterative-fit'] \
+                and not issubclass(self._resampling_strategy, BaseCrossValidator)\
+                and not issubclass(self._resampling_strategy, _RepeatedSplits)\
+                and not issubclass(self._resampling_strategy, BaseShuffleSplit):
             raise ValueError('Illegal resampling strategy: %s' %
                              self._resampling_strategy)
         if self._resampling_strategy in ['partial-cv', 'partial-cv-iterative-fit'] \
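
Note: with the checks added above, a custom resampling strategy passes validation when it is a scikit-learn splitter class (issubclass is used, so a class rather than an instance is expected). A minimal sketch, assuming the private sklearn.model_selection._split names imported at the top of this diff exist in the installed scikit-learn version:

from sklearn.model_selection import KFold, RepeatedKFold, StratifiedShuffleSplit
from sklearn.model_selection._split import (
    BaseCrossValidator, BaseShuffleSplit, _RepeatedSplits)

# Splitter classes that would satisfy the new issubclass checks:
assert issubclass(KFold, BaseCrossValidator)
assert issubclass(RepeatedKFold, _RepeatedSplits)
assert issubclass(StratifiedShuffleSplit, BaseShuffleSplit)
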
@@ -354,6 +374,8 @@ def _fit(self, datamanager, metric):
             exclude_estimators=self._exclude_estimators,
             include_preprocessors=self._include_preprocessors,
             exclude_preprocessors=self._exclude_preprocessors)
+        if only_return_configuration_space:
+            return self.configuration_space
 
         # == RUN ensemble builder
         # Do this before calculating the meta-features to make sure that the
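
Note: the two added lines let _fit stop right after the configuration space has been built. A usage sketch, assuming an already constructed AutoMLClassifier instance named automl and training data X, y (names are illustrative, not part of this commit):

# Returns the ConfigSpace search space instead of running the optimisation loop.
config_space = automl.fit(X, y, only_return_configuration_space=True)
print(config_space)  # inspect candidate hyperparameters without training
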
@@ -532,7 +554,7 @@ def predict(self, X, batch_size=None, n_jobs=1):
         # Each process computes predictions in chunks of batch_size rows.
         all_predictions = joblib.Parallel(n_jobs=n_jobs)(
             joblib.delayed(_model_predict)(self, X, batch_size, identifier)
-            for identifier in self.ensemble_.get_model_identifiers(self.models_))
+            for identifier in self.ensemble_.get_model_identifiers())
 
         if len(all_predictions) == 0:
             raise ValueError('Something went wrong generating the predictions. '
@@ -607,7 +629,8 @@ def _get_ensemble_process(self, time_left_for_ensembles,
             seed=self._seed,
             shared_mode=self._shared_mode,
             precision=precision,
-            max_iterations=max_iterations)
+            max_iterations=max_iterations,
+            read_at_most=np.inf)
 
     def _load_models(self):
         if self._shared_mode:
@@ -811,7 +834,8 @@ def __init__(self, *args, **kwargs):
 
     def _perform_input_checks(self, X, y):
         X = self._check_X(X)
-        y = self._check_y(y)
+        if y is not None:
+            y = self._check_y(y)
         return X, y
 
     def _check_X(self, X):
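
Note: the added guard lets the same helper validate a feature matrix that arrives without targets. A tiny sketch, assuming a BaseAutoML subclass instance named automl (illustrative):

# y may now be None; only X is validated and y is returned untouched.
X_checked, y_checked = automl._perform_input_checks(X_unlabeled, None)
assert y_checked is None
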
@@ -865,12 +889,21 @@ def __init__(self, *args, **kwargs):
                      'multiclass': MULTICLASS_CLASSIFICATION,
                      'binary': BINARY_CLASSIFICATION}
 
-    def fit(self, X, y,
-            metric=None,
-            loss=None,
-            feat_type=None,
-            dataset_name=None):
+    def fit(
+        self, X, y,
+        X_test=None,
+        y_test=None,
+        metric=None,
+        feat_type=None,
+        dataset_name=None,
+        only_return_configuration_space=False,
+    ):
         X, y = self._perform_input_checks(X, y)
+        if X_test is not None:
+            X_test, y_test = self._perform_input_checks(X_test, y_test)
+            if len(y.shape) != len(y_test.shape):
+                raise ValueError('Target value shapes do not match: %s vs %s'
+                                 % (y.shape, y_test.shape))
 
         y_task = type_of_target(y)
         task = self._task_mapping.get(y_task)
@@ -884,8 +917,31 @@ def fit(self, X, y,
             metric = accuracy
 
         y, self._classes, self._n_classes = self._process_target_classes(y)
-
-        return super().fit(X, y, task, metric, feat_type, dataset_name)
+        if y_test is not None:
+            # Map test values to actual values - TODO: copy to all kinds of
+            # other parts in this code and test it!!!
+            y_test_new = []
+            for output_idx in range(len(self._classes)):
+                mapping = {self._classes[output_idx][idx]: idx
+                           for idx in range(len(self._classes[output_idx]))}
+                enumeration = y_test if len(self._classes) == 1 else y_test[output_idx]
+                y_test_new.append(
+                    np.array([mapping[value] for value in enumeration])
+                )
+            y_test = np.array(y_test_new)
+            if self._n_outputs == 1:
+                y_test = y_test.flatten()
+
+        return super().fit(
+            X, y,
+            X_test=X_test,
+            y_test=y_test,
+            task=task,
+            metric=metric,
+            feat_type=feat_type,
+            dataset_name=dataset_name,
+            only_return_configuration_space=only_return_configuration_space,
+        )
 
     def fit_ensemble(self, y, task=None, metric=None, precision='32',
                      dataset_name=None, ensemble_nbest=None,
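
Note: the block above re-encodes the raw test labels with the class-to-index mapping derived from the training targets by _process_target_classes. A standalone illustration of the single-output case (values made up):

import numpy as np

classes = np.array(['cat', 'dog', 'fish'])            # e.g. self._classes[0]
y_test = np.array(['dog', 'cat', 'dog', 'fish'])
mapping = {cls: idx for idx, cls in enumerate(classes)}
y_test_encoded = np.array([mapping[value] for value in y_test])
print(y_test_encoded)  # [1 0 1 2]
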
@@ -918,7 +974,7 @@ def _process_target_classes(self, y):
             _classes.append(classes_k)
             _n_classes.append(classes_k.shape[0])
 
-        self._n_classes = np.array(_n_classes, dtype=np.int)
+        _n_classes = np.array(_n_classes, dtype=np.int)
 
         return y, _classes, _n_classes
 
@@ -948,16 +1004,32 @@ def predict_proba(self, X, batch_size=None, n_jobs=1):
 
 
 class AutoMLRegressor(BaseAutoML):
-    def fit(self, X, y, metric=None, feat_type=None, dataset_name=None):
+    def fit(
+        self, X, y,
+        X_test=None,
+        y_test=None,
+        metric=None,
+        feat_type=None,
+        dataset_name=None,
+        only_return_configuration_space=False,
+    ):
         X, y = super()._perform_input_checks(X, y)
         _n_outputs = 1 if len(y.shape) == 1 else y.shape[1]
         if _n_outputs > 1:
             raise NotImplementedError(
                 'Multi-output regression is not implemented.')
         if metric is None:
             metric = r2
-        return super().fit(X, y, task=REGRESSION, metric=metric,
-                           feat_type=feat_type, dataset_name=dataset_name)
+        return super().fit(
+            X, y,
+            X_test=X_test,
+            y_test=y_test,
+            task=REGRESSION,
+            metric=metric,
+            feat_type=feat_type,
+            dataset_name=dataset_name,
+            only_return_configuration_space=only_return_configuration_space,
+        )
 
     def fit_ensemble(self, y, task=None, metric=None, precision='32',
                      dataset_name=None, ensemble_nbest=None,
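
Note: both front-end estimators now accept an optional test split that is forwarded to the base class and on to the data manager. A hedged usage sketch for the regressor, assuming an AutoMLRegressor instance named automl and pre-split data (all names illustrative):

# The test split travels with the training data; prediction is unchanged.
automl.fit(X_train, y_train, X_test=X_test, y_test=y_test)
y_pred = automl.predict(X_test)
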

0 commit comments