Skip to content

Commit 37c879f

Browse files
Ensemble in dask (#992)
* Ensemble with Dask * flake fix * More debug info on failure * Feedback from comments * Fix test log file * move to proc * Flake8 * Move to dask fixture * Move test to use the fixture * Increase run time to remove random crash * Address commit feedback * Fix outputs * test pytest fixtures * continue moving to pytest tests * close dask client and cluster * start local cluster differently * more debug output * run all tests again * replace close by shutdown * more debug output * shutdown and close clients * proactively delete dask client objects * fix fixture directory, reduce debug output * Incorporate feedback from comments * intermediate commit refactoring unit tests * Added pytest test to ensemble * Fixing dummy classifier merge conflict * additional msg to dict * Only one active ensemble * Feedback from comments * Added missing pickle test * Moving to thread based ensemble * sleep when needed * Minor changes to ensemble scheduling * use future.result() to wait for a future instead of active wait with sleep * store hash of ensemble training data in status pickle file * build ensemble via SMAC callback * fix ensemble time limit * bump SMAC requirement * PEP8 * fix tests? * further stabilize tests * robustify examples * improve unit tests Co-authored-by: chico <[email protected]>
1 parent c363cd6 commit 37c879f

27 files changed

+2866
-2123
lines changed

autosklearn/automl.py

Lines changed: 125 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# -*- encoding: utf-8 -*-
2+
import copy
23
import io
34
import json
4-
import multiprocessing
55
import platform
66
import os
77
import sys
8-
from typing import Optional, List, Union
8+
import time
9+
from typing import Any, Dict, Optional, List, Union
910
import unittest.mock
1011
import warnings
12+
import tempfile
1113

1214
from ConfigSpace.read_and_write import json as cs_json
1315
import dask.distributed
@@ -38,7 +40,7 @@
3840
from autosklearn.util.stopwatch import StopWatch
3941
from autosklearn.util.logging_ import get_logger, setup_logger
4042
from autosklearn.util import pipeline, RE_PATTERN
41-
from autosklearn.ensemble_builder import EnsembleBuilder
43+
from autosklearn.ensemble_builder import EnsembleBuilderManager
4244
from autosklearn.ensembles.singlebest_ensemble import SingleBest
4345
from autosklearn.smbo import AutoMLSMBO
4446
from autosklearn.util.hash import hash_array_or_matrix
@@ -100,9 +102,8 @@ def __init__(self,
100102
ensemble_size=1,
101103
ensemble_nbest=1,
102104
max_models_on_disc=1,
103-
ensemble_memory_limit: Optional[int] = 1024,
104105
seed=1,
105-
ml_memory_limit=3072,
106+
memory_limit=3072,
106107
metadata_directory=None,
107108
debug_mode=False,
108109
include_estimators=None,
@@ -131,9 +132,8 @@ def __init__(self,
131132
self._ensemble_size = ensemble_size
132133
self._ensemble_nbest = ensemble_nbest
133134
self._max_models_on_disc = max_models_on_disc
134-
self._ensemble_memory_limit = ensemble_memory_limit
135135
self._seed = seed
136-
self._ml_memory_limit = ml_memory_limit
136+
self._memory_limit = memory_limit
137137
self._data_memory_limit = None
138138
self._metadata_directory = metadata_directory
139139
self._include_estimators = include_estimators
@@ -171,6 +171,7 @@ def __init__(self,
171171
self._resampling_strategy_arguments['folds'] = 5
172172
self._n_jobs = n_jobs
173173
self._dask_client = dask_client
174+
174175
self.precision = precision
175176
self._disable_evaluator_output = disable_evaluator_output
176177
# Check arguments prior to doing anything!
@@ -207,8 +208,7 @@ def __init__(self,
207208

208209
self.InputValidator = InputValidator()
209210

210-
# Place holder for the run history of the
211-
# Ensemble building process
211+
# The ensemble performance history through time
212212
self.ensemble_performance_history = []
213213

214214
if not isinstance(self._time_for_task, int):
@@ -221,6 +221,40 @@ def __init__(self,
221221
# After assigning and checking variables...
222222
# self._backend = Backend(self._output_dir, self._tmp_dir)
223223

224+
def _create_dask_client(self):
225+
self._is_dask_client_internally_created = True
226+
processes = False
227+
if self._n_jobs is not None and self._n_jobs > 1:
228+
processes = True
229+
dask.config.set({'distributed.worker.daemon': False})
230+
self._dask_client = dask.distributed.Client(
231+
dask.distributed.LocalCluster(
232+
n_workers=self._n_jobs,
233+
processes=processes,
234+
threads_per_worker=1,
235+
# We use the temporary directory to save the
236+
# dask workers, because deleting workers takes
237+
# more time than deleting backend directories.
238+
# This prevents an error saying that the worker
239+
# file was deleted, so the client could not close
240+
# the worker properly
241+
local_directory=tempfile.gettempdir(),
242+
)
243+
)
244+
245+
def _close_dask_client(self):
246+
if (
247+
hasattr(self, '_is_dask_client_internally_created')
248+
and self._is_dask_client_internally_created
249+
and self._dask_client
250+
):
251+
self._dask_client.shutdown()
252+
self._dask_client.close()
253+
del self._dask_client
254+
self._dask_client = None
255+
self._is_dask_client_internally_created = False
256+
del self._is_dask_client_internally_created
257+
224258
def _get_logger(self, name):
225259
logger_name = 'AutoML(%d):%s' % (self._seed, name)
226260
setup_logger(os.path.join(self._backend.temporary_directory,
@@ -256,7 +290,7 @@ def _do_dummy_prediction(self, datamanager, num_run):
256290

257291
self._logger.info("Starting to create dummy predictions.")
258292

259-
memory_limit = self._ml_memory_limit
293+
memory_limit = self._memory_limit
260294
if memory_limit is not None:
261295
memory_limit = int(memory_limit)
262296

@@ -350,6 +384,16 @@ def fit(
350384
raise ValueError('Metric must be instance of '
351385
'autosklearn.metrics.Scorer.')
352386

387+
# If no dask client was provided, we create one, so that we can
388+
# start an ensemble process in parallel to the SMBO optimization
389+
if (
390+
self._dask_client is None and
391+
(self._ensemble_size > 0 or self._n_jobs is not None and self._n_jobs > 1)
392+
):
393+
self._create_dask_client()
394+
else:
395+
self._is_dask_client_internally_created = False
396+
353397
if dataset_name is None:
354398
dataset_name = hash_array_or_matrix(X)
355399

@@ -425,9 +469,8 @@ def fit(
425469
self._logger.debug(' ensemble_size: %d', self._ensemble_size)
426470
self._logger.debug(' ensemble_nbest: %f', self._ensemble_nbest)
427471
self._logger.debug(' max_models_on_disc: %s', str(self._max_models_on_disc))
428-
self._logger.debug(' ensemble_memory_limit: %d', self._ensemble_memory_limit)
429472
self._logger.debug(' seed: %d', self._seed)
430-
self._logger.debug(' ml_memory_limit: %s', str(self._ml_memory_limit))
473+
self._logger.debug(' memory_limit: %s', str(self._memory_limit))
431474
self._logger.debug(' metadata_directory: %s', self._metadata_directory)
432475
self._logger.debug(' debug_mode: %s', self._debug_mode)
433476
self._logger.debug(' include_estimators: %s', str(self._include_estimators))
@@ -512,6 +555,7 @@ def fit(
512555
include_preprocessors=self._include_preprocessors,
513556
exclude_preprocessors=self._exclude_preprocessors)
514557
if only_return_configuration_space:
558+
self._close_dask_client()
515559
return self.configuration_space
516560

517561
# == RUN ensemble builder
@@ -522,35 +566,39 @@ def fit(
522566
self._stopwatch.start_task(ensemble_task_name)
523567
elapsed_time = self._stopwatch.wall_elapsed(self._dataset_name)
524568
time_left_for_ensembles = max(0, self._time_for_task - elapsed_time)
569+
proc_ensemble = None
525570
if time_left_for_ensembles <= 0:
526-
self._proc_ensemble = None
527571
# Fit only raises error when ensemble_size is not zero but
528572
# time_left_for_ensembles is zero.
529573
if self._ensemble_size > 0:
530574
raise ValueError("Not starting ensemble builder because there "
531575
"is no time left. Try increasing the value "
532576
"of time_left_for_this_task.")
533577
elif self._ensemble_size <= 0:
534-
self._proc_ensemble = None
535578
self._logger.info('Not starting ensemble builder because '
536579
'ensemble size is <= 0.')
537580
else:
538581
self._logger.info(
539582
'Start Ensemble with %5.2fsec time left' % time_left_for_ensembles)
540583

541-
# Create a queue to communicate with the ensemble process
542-
# And get the run history
543-
# Use a Manager as a workaround to memory errors cause
544-
# by three subprocesses (Automl-ensemble_builder-pynisher)
545-
mgr = multiprocessing.Manager()
546-
mgr.Namespace()
547-
queue = mgr.Queue()
548-
549-
self._proc_ensemble = self._get_ensemble_process(
550-
time_left_for_ensembles,
551-
queue=queue,
584+
proc_ensemble = EnsembleBuilderManager(
585+
start_time=time.time(),
586+
time_left_for_ensembles=time_left_for_ensembles,
587+
backend=copy.deepcopy(self._backend),
588+
dataset_name=dataset_name,
589+
task=task,
590+
metric=self._metric,
591+
ensemble_size=self._ensemble_size,
592+
ensemble_nbest=self._ensemble_nbest,
593+
max_models_on_disc=self._max_models_on_disc,
594+
seed=self._seed,
595+
precision=self.precision,
596+
max_iterations=None,
597+
read_at_most=np.inf,
598+
ensemble_memory_limit=self._memory_limit,
599+
logger_name=self._logger.name,
600+
random_state=self._seed,
552601
)
553-
self._proc_ensemble.start()
554602

555603
self._stopwatch.stop_task(ensemble_task_name)
556604

@@ -603,7 +651,7 @@ def fit(
603651
backend=self._backend,
604652
total_walltime_limit=time_left_for_smac,
605653
func_eval_time_limit=per_run_time_limit,
606-
memory_limit=self._ml_memory_limit,
654+
memory_limit=self._memory_limit,
607655
data_memory_limit=self._data_memory_limit,
608656
watcher=self._stopwatch,
609657
n_jobs=self._n_jobs,
@@ -623,6 +671,7 @@ def fit(
623671
disable_file_output=self._disable_evaluator_output,
624672
get_smac_object_callback=self._get_smac_object_callback,
625673
smac_scenario_args=self._smac_scenario_args,
674+
ensemble_callback=proc_ensemble,
626675
)
627676

628677
try:
@@ -642,14 +691,15 @@ def fit(
642691

643692
# Wait until the ensemble process is finished to avoid shutting down
644693
# while the ensemble builder tries to access the data
645-
if self._proc_ensemble is not None and self._ensemble_size > 0:
646-
self._proc_ensemble.join()
647-
self.ensemble_performance_history = self._proc_ensemble.get_ensemble_history()
648-
649-
self._proc_ensemble = None
694+
if proc_ensemble is not None:
695+
self.ensemble_performance_history = list(proc_ensemble.history)
696+
if len(proc_ensemble.futures) > 0:
697+
future = proc_ensemble.futures.pop()
698+
future.cancel()
650699

651700
if load_models:
652701
self._load_models()
702+
self._close_dask_client()
653703

654704
return self
655705

@@ -798,62 +848,42 @@ def fit_ensemble(self, y, task=None, precision=32,
798848
# Make sure that input is valid
799849
y = self.InputValidator.validate_target(y, is_classification=True)
800850

801-
self._proc_ensemble = self._get_ensemble_process(
802-
1, task, precision, dataset_name, max_iterations=1,
803-
ensemble_nbest=ensemble_nbest, ensemble_size=ensemble_size)
804-
self._proc_ensemble.main()
805-
self.ensemble_performance_history = self._proc_ensemble.get_ensemble_history()
806-
self._proc_ensemble = None
807-
self._load_models()
808-
return self
809-
810-
def _get_ensemble_process(self, time_left_for_ensembles,
811-
task=None, precision=None,
812-
dataset_name=None, max_iterations=None,
813-
ensemble_nbest=None, ensemble_size=None, queue=None):
814-
815-
if task is None:
816-
task = self._task
817-
else:
818-
self._task = task
819-
820-
if precision is None:
821-
precision = self.precision
822-
else:
823-
self.precision = precision
824-
825-
if dataset_name is None:
826-
dataset_name = self._dataset_name
827-
else:
828-
self._dataset_name = dataset_name
829-
830-
if ensemble_nbest is None:
831-
ensemble_nbest = self._ensemble_nbest
832-
else:
833-
self._ensemble_nbest = ensemble_nbest
834-
835-
if ensemble_size is None:
836-
ensemble_size = self._ensemble_size
851+
# Create a client if needed
852+
if self._dask_client is None:
853+
self._create_dask_client()
837854
else:
838-
self._ensemble_size = ensemble_size
839-
840-
return EnsembleBuilder(
841-
backend=self._backend,
842-
dataset_name=dataset_name,
843-
task_type=task,
855+
self._is_dask_client_internally_created = False
856+
857+
# Use the current thread to start the ensemble builder process
858+
# The function ensemble_builder_process will internally create an ensemble
859+
# builder in the provided dask client
860+
manager = EnsembleBuilderManager(
861+
start_time=time.time(),
862+
time_left_for_ensembles=self._time_for_task,
863+
backend=copy.deepcopy(self._backend),
864+
dataset_name=dataset_name if dataset_name else self._dataset_name,
865+
task=task if task else self._task,
844866
metric=self._metric,
845-
limit=time_left_for_ensembles,
846-
ensemble_size=ensemble_size,
847-
ensemble_nbest=ensemble_nbest,
867+
ensemble_size=ensemble_size if ensemble_size else self._ensemble_size,
868+
ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
848869
max_models_on_disc=self._max_models_on_disc,
849870
seed=self._seed,
850-
precision=precision,
851-
max_iterations=max_iterations,
871+
precision=precision if precision else self.precision,
872+
max_iterations=1,
852873
read_at_most=np.inf,
853-
memory_limit=self._ensemble_memory_limit,
874+
ensemble_memory_limit=self._memory_limit,
854875
random_state=self._seed,
855-
queue=queue,
876+
logger_name=self._logger.name,
856877
)
878+
manager.build_ensemble(self._dask_client)
879+
future = manager.futures.pop()
880+
dask.distributed.wait([future]) # wait for the ensemble process to finish
881+
result = future.result()
882+
self.ensemble_performance_history, _ = result
883+
884+
self._load_models()
885+
self._close_dask_client()
886+
return self
857887

858888
def _load_models(self):
859889
self.ensemble_ = self._backend.load_ensemble(self._seed)
@@ -914,6 +944,7 @@ def _load_best_individual_model(self):
914944
metric=self._metric,
915945
random_state=self._seed,
916946
run_history=self.runhistory_,
947+
model_dir=self._backend.get_model_dir(),
917948
)
918949
self._logger.warning(
919950
"No valid ensemble was created. Please check the log"
@@ -1127,6 +1158,19 @@ def _create_search_space(self, tmp_dir, backend, datamanager,
11271158
def configuration_space_created_hook(self, datamanager, configuration_space):
11281159
return configuration_space
11291160

1161+
def __getstate__(self) -> Dict[str, Any]:
1162+
# Cannot serialize a client!
1163+
self._dask_client = None
1164+
return self.__dict__
1165+
1166+
def __del__(self):
1167+
self._close_dask_client()
1168+
1169+
# When multiprocessing work is done, the
1170+
# objects are deleted. We don't want to delete run areas
1171+
# until the estimator is deleted
1172+
self._backend.context.delete_directories(force=False)
1173+
11301174

11311175
class AutoMLClassifier(AutoML):
11321176
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)