Skip to content

Commit 7b4702d

Browse files
authored
Merge pull request #338 from LGro/master
add method to get list of evaluated models (#336)
2 parents 7bd00bb + 561c64f commit 7b4702d

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

autosklearn/automl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,12 +764,23 @@ def sprint_statistics(self):
764764
'limit: %d\n' % num_memout)
765765
return sio.getvalue()
766766

767-
def show_models(self):
767+
def get_models_with_weights(self):
768768
if self.models_ is None or len(self.models_) == 0 or \
769769
self.ensemble_ is None:
770770
self._load_models()
771771

772-
return self.ensemble_.pprint_ensemble_string(self.models_)
772+
return self.ensemble_.get_models_with_weights(self.models_)
773+
774+
def show_models(self):
775+
models_with_weights = self.get_models_with_weights()
776+
777+
with io.StringIO() as sio:
778+
sio.write("[")
779+
for weight, model in models_with_weights:
780+
sio.write("(%f, %s),\n" % (weight, model))
781+
sio.write("]")
782+
783+
return sio.getvalue()
773784

774785
def _create_search_space(self, tmp_dir, backend, datamanager,
775786
include_estimators=None,

autosklearn/ensembles/abstract_ensemble.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def predict(self, base_models_predictions):
4242
self
4343

4444
@abstractmethod
45-
def pprint_ensemble_string(self, models):
46-
"""Return a nicely-readable representation of the ensmble.
45+
def get_models_with_weights(self, models):
46+
"""Return a list of (weight, model) pairs
4747
4848
Parameters
4949
----------
@@ -53,9 +53,10 @@ def pprint_ensemble_string(self, models):
5353
5454
Returns
5555
-------
56-
str
56+
array : [(weight_1, model_1), ..., (weight_n, model_n)]
5757
"""
5858

59+
5960
@abstractmethod
6061
def get_model_identifiers(self):
6162
"""Return identifiers of models in the ensemble.

autosklearn/ensembles/ensemble_selection.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import random
33

44
import numpy as np
5-
import six
65

76
from autosklearn.constants import *
87
from autosklearn.ensembles.abstract_ensemble import AbstractEnsemble
@@ -204,9 +203,9 @@ def __str__(self):
204203
enumerate(self.identifiers_)
205204
if self.weights_[idx] > 0]))
206205

207-
def pprint_ensemble_string(self, models):
206+
def get_models_with_weights(self, models):
208207
output = []
209-
sio = six.StringIO()
208+
210209
for i, weight in enumerate(self.weights_):
211210
identifier = self.identifiers_[i]
212211
model = models[identifier]
@@ -215,12 +214,7 @@ def pprint_ensemble_string(self, models):
215214

216215
output.sort(reverse=True, key=lambda t: t[0])
217216

218-
sio.write("[")
219-
for weight, model in output:
220-
sio.write("(%f, %s),\n" % (weight, model))
221-
sio.write("]")
222-
223-
return sio.getvalue()
217+
return output
224218

225219
def get_model_identifiers(self):
226220
return self.identifiers_

autosklearn/estimators.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ def show_models(self):
7373
"""
7474
return self._automl.show_models()
7575

76+
def get_models_with_weights(self):
77+
"""Return a list of the final ensemble found by auto-sklearn.
78+
79+
Returns
80+
-------
81+
[(weight_1, model_1), ..., (weight_n, model_n)]
82+
83+
"""
84+
return self._automl.get_models_with_weights()
85+
7686
@property
7787
def cv_results_(self):
7888
return self._automl.cv_results_

0 commit comments

Comments
 (0)