Skip to content

Commit d3aafbd

Browse files
authored
Improve ensemble selection memory usage (#997)
* Improve ensemble selection memory usage * separate storage of data inside the ensemble builder into two dictionaries to separately store them on disk (one for scores and one for predictions). If we run over the allocated memory, we can still make use of the stored scores * do not stop the ensemble builder when the number of models to consider can no longer be reduced so that the ensemble builder can still delete models from the hard drive if necessary. * avoid unnecessary memory copies during ensemble construction * improve structure of temporary directories during unit testing * delete files before building an ensemble * flake8 * reorder function calls in ensemble builder
1 parent 9952f67 commit d3aafbd

File tree

14 files changed

+461
-333
lines changed

14 files changed

+461
-333
lines changed

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ matrix:
2323
- os: linux
2424
env: DISTRIB="conda" DOCPUSH="true" PYTHON="3.7" SKIP_TESTS="true"
2525
- os: linux
26-
env: DISTRIB="conda" RUN_FLAKE8="true" SKIP_TESTS="true"
26+
env: DISTRIB="conda" PYTHON="3.8" RUN_FLAKE8="true" SKIP_TESTS="true"
2727
- os: linux
28-
env: DISTRIB="conda" RUN_MYPY="true" SKIP_TESTS="true"
28+
env: DISTRIB="conda" PYTHON="3.8" RUN_MYPY="true" SKIP_TESTS="true"
2929
- os: linux
3030
env: DISTRIB="conda" COVERAGE="true" PYTHON="3.6"
3131
- os: linux

autosklearn/automl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,10 @@ def fit_ensemble(self, y, task=None, precision=32,
872872
future = manager.futures.pop()
873873
dask.distributed.wait([future]) # wait for the ensemble process to finish
874874
result = future.result()
875-
self.ensemble_performance_history, _ = result
875+
if result is None:
876+
raise ValueError("Error building the ensemble - please check the log file and command "
877+
"line output for error messages.")
878+
self.ensemble_performance_history, _, _, _, _ = result
876879

877880
self._load_models()
878881
self._close_dask_client()

autosklearn/ensemble_builder.py

Lines changed: 190 additions & 135 deletions
Large diffs are not rendered by default.

autosklearn/ensembles/abstract_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABCMeta, abstractmethod
2-
from typing import Dict, List, Tuple
2+
from typing import Dict, List, Tuple, Union
33

44
import numpy as np
55

@@ -40,7 +40,7 @@ def fit(
4040
pass
4141

4242
@abstractmethod
43-
def predict(self, base_models_predictions: np.ndarray) -> np.ndarray:
43+
def predict(self, base_models_predictions: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
4444
"""Create ensemble predictions from the base model predictions.
4545
4646
Parameters

autosklearn/ensembles/ensemble_selection.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import Counter
3-
from typing import Any, Dict, List, Tuple, cast
3+
from typing import Any, Dict, List, Tuple, Union, cast
44

55
import numpy as np
66

@@ -265,27 +265,32 @@ def _bagging(
265265
dtype=np.int64,
266266
)
267267

268-
def predict(self, predictions: np.ndarray) -> np.ndarray:
269-
predictions = np.asarray(
270-
predictions,
271-
dtype=np.float64,
272-
)
268+
def predict(self, predictions: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
269+
270+
average = np.zeros_like(predictions[0], dtype=np.float64)
271+
tmp_predictions = np.empty_like(predictions[0], dtype=np.float64)
273272

274273
# if predictions.shape[0] == len(self.weights_),
275274
# predictions include those of zero-weight models.
276-
if predictions.shape[0] == len(self.weights_):
277-
return np.average(predictions, axis=0, weights=self.weights_)
275+
if len(predictions) == len(self.weights_):
276+
for pred, weight in zip(predictions, self.weights_):
277+
np.multiply(pred, weight, out=tmp_predictions)
278+
np.add(average, tmp_predictions, out=average)
278279

279280
# if prediction model.shape[0] == len(non_null_weights),
280281
# predictions do not include those of zero-weight models.
281-
elif predictions.shape[0] == np.count_nonzero(self.weights_):
282+
elif len(predictions) == np.count_nonzero(self.weights_):
282283
non_null_weights = [w for w in self.weights_ if w > 0]
283-
return np.average(predictions, axis=0, weights=non_null_weights)
284+
for pred, weight in zip(predictions, non_null_weights):
285+
np.multiply(pred, weight, out=tmp_predictions)
286+
np.add(average, tmp_predictions, out=average)
284287

285288
# If none of the above applies, then something must have gone wrong.
286289
else:
287290
raise ValueError("The dimensions of ensemble predictions"
288291
" and ensemble weights do not match!")
292+
del tmp_predictions
293+
return average
289294

290295
def __str__(self) -> str:
291296
return 'Ensemble Selection:\n\tTrajectory: %s\n\tMembers: %s' \

autosklearn/ensembles/singlebest_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import List, Tuple
2+
from typing import List, Tuple, Union
33

44
import numpy as np
55

@@ -85,7 +85,7 @@ def get_identifiers_from_run_history(self) -> List[Tuple[int, int, float]]:
8585

8686
return best_model_identifier
8787

88-
def predict(self, predictions: np.ndarray) -> np.ndarray:
88+
def predict(self, predictions: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
8989
return predictions[0]
9090

9191
def __str__(self) -> str:

autosklearn/estimators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
logging_config=None,
4343
metadata_directory=None,
4444
metric=None,
45+
load_models: bool = True,
4546
):
4647
"""
4748
Parameters
@@ -216,6 +217,9 @@ def __init__(
216217
:meth:`autosklearn.metrics.make_scorer`. These are the `Built-in
217218
Metrics`_.
218219
If None is provided, a default metric is selected depending on the task.
220+
221+
load_models : bool, optional (True)
222+
Whether to load the models after fitting Auto-sklearn.
219223
220224
Attributes
221225
----------
@@ -257,6 +261,7 @@ def __init__(
257261
self.logging_config = logging_config
258262
self.metadata_directory = metadata_directory
259263
self._metric = metric
264+
self._load_models = load_models
260265

261266
self.automl_ = None # type: Optional[AutoML]
262267
# n_jobs after conversion to a number (b/c default is None)
@@ -340,7 +345,7 @@ def fit(self, **kwargs):
340345
tmp_folder=self.tmp_folder,
341346
output_folder=self.output_folder,
342347
)
343-
self.automl_.fit(load_models=True, **kwargs)
348+
self.automl_.fit(load_models=self._load_models, **kwargs)
344349

345350
return self
346351

test/test_automl/test_automl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def test_automl_outputs(backend, dask_client):
305305
'start_time_100',
306306
'datamanager.pkl',
307307
'ensemble_read_preds.pkl',
308+
'ensemble_read_scores.pkl',
308309
'runs',
309310
'ensembles',
310311
]

test/test_ensemble_builder/ensemble_utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import os
2+
import shutil
23
import unittest
34

4-
55
import numpy as np
66

77
from autosklearn.metrics import make_scorer
88
from autosklearn.ensemble_builder import (
9-
EnsembleBuilder,
9+
EnsembleBuilder, AbstractEnsemble
1010
)
1111

12-
this_directory = os.path.dirname(__file__)
13-
1412

1513
def scorer_function(a, b):
1614
return 0.9
@@ -21,22 +19,19 @@ def scorer_function(a, b):
2119

2220
class BackendMock(object):
2321

24-
def __init__(self):
22+
def __init__(self, target_directory):
2523
this_directory = os.path.abspath(
2624
os.path.dirname(__file__)
2725
)
28-
self.temporary_directory = os.path.join(
29-
this_directory, 'data',
30-
)
31-
self.internals_directory = os.path.join(
32-
this_directory, 'data', '.auto-sklearn',
33-
)
26+
shutil.copytree(os.path.join(this_directory, 'data'), os.path.join(target_directory))
27+
self.temporary_directory = target_directory
28+
self.internals_directory = os.path.join(self.temporary_directory, '.auto-sklearn')
3429

3530
def load_datamanager(self):
3631
manager = unittest.mock.Mock()
3732
manager.__reduce__ = lambda self: (unittest.mock.MagicMock, ())
3833
array = np.load(os.path.join(
39-
this_directory, 'data',
34+
self.temporary_directory,
4035
'.auto-sklearn',
4136
'runs', '0_3_100.0',
4237
'predictions_test_0_3_100.0.npy'
@@ -60,7 +55,7 @@ def save_predictions_as_txt(self, predictions, subset, idx, prefix, precision):
6055
return
6156

6257
def get_runs_directory(self) -> str:
63-
return os.path.join(this_directory, 'data', '.auto-sklearn', 'runs')
58+
return os.path.join(self.temporary_directory, '.auto-sklearn', 'runs')
6459

6560
def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
6661
return os.path.join(self.get_runs_directory(), '%d_%d_%s' % (seed, num_run, budget))
@@ -97,4 +92,11 @@ def compare_read_preds(read_preds1, read_preds2):
9792
class EnsembleBuilderMemMock(EnsembleBuilder):
9893

9994
def fit_ensemble(self, selected_keys):
95+
return True
96+
97+
def predict(self, set_: str,
98+
ensemble: AbstractEnsemble,
99+
selected_keys: list,
100+
n_preds: int,
101+
index_run: int):
100102
np.ones([10000000, 1000000])

0 commit comments

Comments
 (0)