Skip to content

Commit f8b04a7

Browse files
committed
Merge branch 'development_java'
2 parents 39974ba + 9d8d5fb commit f8b04a7

File tree

166 files changed

+5747
-3092
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

166 files changed

+5747
-3092
lines changed

autosklearn/automl.py

Lines changed: 70 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def __init__(self,
187187
self._metric = None
188188
self._label_num = None
189189
self.models_ = None
190-
self.ensemble_indices_ = None
190+
self.ensemble_ = None
191+
self._can_predict = False
191192

192193
self._debug_mode = debug_mode
193194
self._backend = Backend(self._output_dir, self._tmp_dir)
@@ -242,9 +243,14 @@ def fit(self, X, y,
242243
raise ValueError('Array feat_type does not have same number of '
243244
'variables as X has features. %d vs %d.' %
244245
(len(feat_type), X.shape[1]))
245-
if feat_type is not None and not all([isinstance(f, bool)
246+
if feat_type is not None and not all([isinstance(f, str)
246247
for f in feat_type]):
247-
raise ValueError('Array feat_type must only contain bools.')
248+
raise ValueError('Array feat_type must only contain strings.')
249+
if feat_type is not None:
250+
for ft in feat_type:
251+
if ft.lower() not in ['categorical', 'numerical']:
252+
raise ValueError('Only `Categorical` and `Numerical` are '
253+
'valid feature types, you passed `%s`' % ft)
248254

249255
loaded_data_manager = XYDataManager(X, y,
250256
task=task,
@@ -298,16 +304,19 @@ def _print_load_time(basename, time_left_for_this_task,
298304
return time_for_load_data
299305

300306
def _do_dummy_prediction(self, datamanager):
307+
self._logger.info("Starting to create dummy predictions.")
301308
autosklearn.cli.base_interface.main(datamanager,
302309
self._resampling_strategy,
303310
None,
304311
None,
305-
mode_args=self._resampling_strategy_arguments)
312+
mode_args=self._resampling_strategy_arguments,
313+
output_dir=self._tmp_dir)
314+
self._logger.info("Finished creating dummy predictions.")
306315

307316
def _fit(self, datamanager):
308317
# Reset learnt stuff
309318
self.models_ = None
310-
self.ensemble_indices_ = None
319+
self.ensemble_ = None
311320

312321
# Check arguments prior to doing anything!
313322
if self._resampling_strategy not in ['holdout', 'holdout-iterative-fit',
@@ -352,7 +361,8 @@ def _fit(self, datamanager):
352361
self._logger)
353362

354363
# == Perform dummy predictions
355-
self._do_dummy_prediction(datamanager)
364+
if self._resampling_strategy in ['holdout', 'holdout-iterative-fit']:
365+
self._do_dummy_prediction(datamanager)
356366

357367
# = Create a searchspace
358368
# Do this before One Hot Encoding to make sure that it creates a
@@ -371,6 +381,12 @@ def _fit(self, datamanager):
371381
self._include_preprocessors)
372382
self.configuration_space_created_hook(datamanager)
373383

384+
# == RUN ensemble builder
385+
# Do this before calculating the meta-features to make sure that the
386+
# dummy predictions are actually included in the ensemble even if
387+
# calculating the meta-features takes very long
388+
proc_ensembles = self.run_ensemble_builder()
389+
374390
# == Calculate metafeatures
375391
meta_features = _calculate_metafeatures(
376392
data_feat_type=datamanager.feat_type,
@@ -463,9 +479,12 @@ def _fit(self, datamanager):
463479
TASK_TYPES_TO_STRING[datamanager.info["task"]])
464480

465481
if config is not None:
466-
configuration = Configuration(self.configuration_space, config)
467-
config_string = convert_conf2smac_string(configuration)
468-
initial_configurations = [config_string] + initial_configurations
482+
try:
483+
configuration = Configuration(self.configuration_space, config)
484+
config_string = convert_conf2smac_string(configuration)
485+
initial_configurations = [config_string] + initial_configurations
486+
except ValueError:
487+
pass
469488

470489
# == RUN SMAC
471490
proc_smac = run_smac(tmp_dir=self._tmp_dir, basename=self._dataset_name,
@@ -481,9 +500,6 @@ def _fit(self, datamanager):
481500
resampling_strategy_arguments=self._resampling_strategy_arguments,
482501
shared_mode=self._shared_mode)
483502

484-
# == RUN ensemble builder
485-
proc_ensembles = self.run_ensemble_builder()
486-
487503
procs = []
488504

489505
if proc_smac is not None:
@@ -554,26 +570,43 @@ def run_ensemble_builder(self,
554570
'size 0.')
555571
return None
556572

573+
def refit(self, X, y):
574+
if self._keep_models is not True:
575+
raise ValueError(
576+
"Predict can only be called if 'keep_models==True'")
577+
if self.models_ is None or len(self.models_) == 0 or \
578+
self.ensemble_ is None:
579+
self._load_models()
580+
581+
for identifier in self.models_:
582+
if identifier in self.ensemble_.get_model_identifiers():
583+
model = self.models_[identifier]
584+
# this updates the model inplace, it can then later be used in
585+
# predict method
586+
model.fit(X.copy(), y.copy())
587+
588+
self._can_predict = True
589+
557590
def predict(self, X):
591+
return np.argmax(self.predict_proba(X), axis=1)
592+
593+
def predict_proba(self, X):
558594
if self._keep_models is not True:
559595
raise ValueError(
560596
"Predict can only be called if 'keep_models==True'")
561-
if self._resampling_strategy not in ['holdout',
562-
'holdout-iterative-fit']:
597+
if not self._can_predict and \
598+
self._resampling_strategy not in \
599+
['holdout', 'holdout-iterative-fit']:
563600
raise NotImplementedError(
564601
'Predict is currently only implemented for resampling '
565602
'strategy holdout.')
566603

567-
if self.models_ is None or len(self.models_) == 0 or len(
568-
self.ensemble_indices_) == 0:
604+
if self.models_ is None or len(self.models_) == 0 or \
605+
self.ensemble_ is None:
569606
self._load_models()
570607

571-
predictions = []
572-
for identifier in self.models_:
573-
if identifier not in self.ensemble_indices_:
574-
continue
575-
576-
weight = self.ensemble_indices_[identifier]
608+
all_predictions = []
609+
for identifier in self.ensemble_.get_model_identifiers():
577610
model = self.models_[identifier]
578611

579612
X_ = X.copy()
@@ -588,16 +621,16 @@ def predict(self, X):
588621
"while X_.shape is %s" %
589622
(model, str(prediction.shape),
590623
str(X_.shape)))
591-
predictions.append(prediction * weight)
624+
all_predictions.append(prediction)
592625

593-
if len(predictions) == 0:
626+
if len(all_predictions) == 0:
594627
raise ValueError('Something went wrong generating the predictions. '
595628
'The ensemble should consist of the following '
596629
'models: %s, the following models were loaded: '
597630
'%s' % (str(list(self.ensemble_indices_.keys())),
598631
str(list(self.models_.keys()))))
599632

600-
predictions = np.sum(np.array(predictions), axis=0)
633+
predictions = self.ensemble_.predict(all_predictions)
601634
return predictions
602635

603636
def _load_models(self):
@@ -606,46 +639,32 @@ def _load_models(self):
606639
else:
607640
seed = self._seed
608641

609-
self.models_ = self._backend.load_all_models(seed)
642+
self.ensemble_ = self._backend.load_ensemble(seed)
643+
if self.ensemble_:
644+
identifiers = self.ensemble_.identifiers_
645+
self.models_ = self._backend.load_models_by_identifiers(identifiers)
646+
else:
647+
self.models_ = self._backend.load_all_models(seed)
648+
610649
if len(self.models_) == 0:
611650
raise ValueError('No models fitted!')
612651

613-
self.ensemble_indices_ = self._backend.load_ensemble_indices_weights(
614-
seed)
615652

616653
def score(self, X, y):
617654
# fix: Consider only index 1 of second dimension
618655
# Don't know if the reshaping should be done there or in calculate_score
619-
prediction = self.predict(X)
620-
if self._task == BINARY_CLASSIFICATION:
621-
prediction = prediction[:, 1].reshape((-1, 1))
656+
prediction = self.predict_proba(X)
622657
return calculate_score(y, prediction, self._task,
623658
self._metric, self._label_num,
624659
logger=self._logger)
625660

626661
def show_models(self):
627-
if self.models_ is None or len(self.models_) == 0 or len(
628-
self.ensemble_indices_) == 0:
629-
self._load_models()
630662

631-
output = []
632-
sio = six.StringIO()
633-
for identifier in self.models_:
634-
if identifier not in self.ensemble_indices_:
635-
continue
636-
637-
weight = self.ensemble_indices_[identifier]
638-
model = self.models_[identifier]
639-
output.append((weight, model))
640-
641-
output.sort(reverse=True)
642-
643-
sio.write("[")
644-
for weight, model in output:
645-
sio.write("(%f, %s),\n" % (weight, model))
646-
sio.write("]")
663+
if self.models_ is None or len(self.models_) == 0 or \
664+
self.ensemble_ is None:
665+
self._load_models()
647666

648-
return sio.getvalue()
667+
return self.ensemble_.pprint_ensemble_string(self.models_)
649668

650669
def _save_ensemble_data(self, X, y):
651670
"""Split dataset and store Data for the ensemble script.

autosklearn/cli/HPOlib_interface.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def parse_cli():
8282
return args, parameters
8383

8484

85-
def parse_args(dataset, mode, seed, params, fold, folds):
85+
def parse_args(dataset, mode, seed, params, fold, folds, output_dir=None):
8686
if seed is None:
8787
seed = 1
8888

@@ -107,10 +107,11 @@ def parse_args(dataset, mode, seed, params, fold, folds):
107107
mode_args = None
108108
else:
109109
raise ValueError(mode)
110-
base_interface.main(dataset, mode, seed, params, mode_args=mode_args)
110+
base_interface.main(dataset, mode, seed, params, mode_args=mode_args,
111+
output_dir=output_dir)
111112

112113

113-
def main():
114+
def main(output_dir=None):
114115
args, params = parse_cli()
115116
assert 'dataset' in args
116117
assert 'mode' in args
@@ -124,6 +125,7 @@ def main():
124125
params,
125126
int(args['fold']),
126127
int(args['folds']),
128+
output_dir=output_dir
127129
)
128130

129131

autosklearn/cli/SMAC_interface.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from autosklearn.cli import base_interface
55

6-
def main():
6+
7+
def main(output_dir=None):
78
instance_name = sys.argv[1]
89
instance_specific_information = sys.argv[2]
910
cutoff_time = float(sys.argv[3])
@@ -45,7 +46,7 @@ def main():
4546
raise ValueError(mode)
4647

4748
base_interface.main(instance_specific_information, mode,
48-
seed, params, mode_args=mode_args)
49+
seed, params, mode_args=mode_args, output_dir=output_dir)
4950

5051

5152
if __name__ == '__main__':

0 commit comments

Comments
 (0)