Skip to content

Commit 26760aa

Browse files
authored
Threads and forkserver (#1062)
* use threads again
* try the forkserver
* use forkserver pre-load for faster process starting
* streamline code
* de-duplicate code
* add missing file
* Update parallel.py
* Update parallel.py
1 parent d96f9ce commit 26760aa

File tree

8 files changed

+42
-15
lines changed

8 files changed

+42
-15
lines changed

autosklearn/automl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
get_named_client_logger,
5050
)
5151
from autosklearn.util import pipeline, RE_PATTERN
52+
from autosklearn.util.parallel import preload_modules
5253
from autosklearn.ensemble_builder import EnsembleBuilderManager
5354
from autosklearn.ensembles.singlebest_ensemble import SingleBest
5455
from autosklearn.smbo import AutoMLSMBO
@@ -228,7 +229,7 @@ def __init__(self,
228229
# examples. Nevertheless, multi-process runs
229230
# have spawn as requirement to reduce the
230231
# possibility of a deadlock
231-
self._multiprocessing_context = 'spawn'
232+
self._multiprocessing_context = 'forkserver'
232233
if self._n_jobs == 1 and self._dask_client is None:
233234
self._multiprocessing_context = 'fork'
234235
self._dask_client = SingleThreadedClient()
@@ -248,11 +249,10 @@ def __init__(self,
248249

249250
def _create_dask_client(self):
250251
self._is_dask_client_internally_created = True
251-
dask.config.set({'distributed.worker.daemon': False})
252252
self._dask_client = dask.distributed.Client(
253253
dask.distributed.LocalCluster(
254254
n_workers=self._n_jobs,
255-
processes=True if self._n_jobs != 1 else False,
255+
processes=False,
256256
threads_per_worker=1,
257257
# We use the temporal directory to save the
258258
# dask workers, because deleting workers
@@ -299,8 +299,8 @@ def _get_logger(self, name):
299299
# under the above logging configuration setting
300300
# We need to specify the logger_name so that received records
301301
# are treated under the logger_name ROOT logger setting
302-
context = multiprocessing.get_context(
303-
self._multiprocessing_context)
302+
context = multiprocessing.get_context(self._multiprocessing_context)
303+
preload_modules(context)
304304
self.stop_logging_server = context.Event()
305305
port = context.Value('l') # be safe by using a long
306306
port.value = -1

autosklearn/ensemble_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from autosklearn.ensembles.ensemble_selection import EnsembleSelection
3232
from autosklearn.ensembles.abstract_ensemble import AbstractEnsemble
3333
from autosklearn.util.logging_ import get_named_client_logger
34+
from autosklearn.util.parallel import preload_modules
3435

3536
Y_ENSEMBLE = 0
3637
Y_VALID = 1
@@ -572,11 +573,11 @@ def __init__(
572573
def run(
573574
self,
574575
iteration: int,
576+
pynisher_context: str,
575577
time_left: Optional[float] = None,
576578
end_at: Optional[float] = None,
577579
time_buffer=5,
578580
return_predictions: bool = False,
579-
pynisher_context: str = 'spawn',
580581
):
581582

582583
if time_left is None and end_at is None:
@@ -606,6 +607,7 @@ def run(
606607
if wall_time_in_s < 1:
607608
break
608609
context = multiprocessing.get_context(pynisher_context)
610+
preload_modules(context)
609611

610612
safe_ensemble_script = pynisher.enforce_limits(
611613
wall_time_in_s=wall_time_in_s,

autosklearn/evaluation/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import autosklearn.evaluation.test_evaluator
2525
import autosklearn.evaluation.util
2626
from autosklearn.util.logging_ import get_named_client_logger
27+
from autosklearn.util.parallel import preload_modules
2728

2829

2930
def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
@@ -97,12 +98,12 @@ def _encode_exit_status(exit_status):
9798
class ExecuteTaFuncWithQueue(AbstractTAFunc):
9899

99100
def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
100-
cost_for_crash, abort_on_first_run_crash, port,
101+
cost_for_crash, abort_on_first_run_crash, port, pynisher_context,
101102
initial_num_run=1, stats=None,
102103
run_obj='quality', par_factor=1, scoring_functions=None,
103104
output_y_hat_optimization=True, include=None, exclude=None,
104105
memory_limit=None, disable_file_output=False, init_params=None,
105-
budget_type=None, ta=False, pynisher_context='spawn', **resampling_strategy_args):
106+
budget_type=None, ta=False, **resampling_strategy_args):
106107

107108
if resampling_strategy == 'holdout':
108109
eval_function = autosklearn.evaluation.train_evaluator.eval_holdout
@@ -261,6 +262,7 @@ def run(
261262
) -> Tuple[StatusType, float, float, Dict[str, Union[int, float, str, Dict, List, Tuple]]]:
262263

263264
context = multiprocessing.get_context(self.pynisher_context)
265+
preload_modules(context)
264266
queue = context.Queue()
265267

266268
if not (instance_specific is None or instance_specific == '0'):

autosklearn/util/parallel.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import multiprocessing
2+
import sys
3+
4+
5+
def preload_modules(context: multiprocessing.context.BaseContext) -> None:
    """Register already-imported heavy packages for forkserver pre-loading.

    Scans the modules currently present in ``sys.modules`` and tells
    *context* (expected to be a ``forkserver`` multiprocessing context) to
    import them in the server process up front, so that forked worker
    processes start faster.

    Modules whose name contains ``logging`` are skipped: logging must be
    configured per process rather than inherited from the forkserver.

    Parameters
    ----------
    context : multiprocessing.context.BaseContext
        A context object providing ``set_forkserver_preload``.
    """
    # Snapshot the module table first: sys.modules can be mutated by
    # concurrent imports in other threads, and iterating a live dict view
    # while it grows raises RuntimeError.
    loaded_modules = list(sys.modules)
    heavy_packages = (
        'smac',
        'autosklearn',
        'numpy',
        'scipy',
        'pandas',
        'pynisher',
        'sklearn',
        'ConfigSpace',
    )
    preload = [
        module_name for module_name in loaded_modules
        if module_name.split('.')[0] in heavy_packages
        and 'logging' not in module_name
    ]
    context.set_forkserver_preload(preload)

scripts/run_auto-sklearn_for_metadata_generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@
151151
include=include,
152152
metric=automl_arguments['metric'],
153153
cost_for_crash=get_cost_of_crash(automl_arguments['metric']),
154-
abort_on_first_run_crash=False,)
154+
abort_on_first_run_crash=False,
155+
pynisher_context='fork')
155156
run_info, run_value = ta.run_wrapper(
156157
RunInfo(
157158
config=config,

test/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import time
44
import unittest.mock
55

6-
import dask
76
from dask.distributed import Client, get_client
87
import psutil
98
import pytest
@@ -125,8 +124,7 @@ def dask_client(request):
125124
Workers are in subprocesses to not create deadlocks with the pynisher and logging.
126125
"""
127126

128-
dask.config.set({'distributed.worker.daemon': False})
129-
client = Client(n_workers=2, threads_per_worker=1, processes=True)
127+
client = Client(n_workers=2, threads_per_worker=1, processes=False)
130128
print("Started Dask client={}\n".format(client))
131129

132130
def get_finalizer(address):
@@ -151,7 +149,6 @@ def dask_client_single_worker(request):
151149
it is used very rarely to avoid this issue as much as possible.
152150
"""
153151

154-
dask.config.set({'distributed.worker.daemon': False})
155152
client = Client(n_workers=1, threads_per_worker=1, processes=False)
156153
print("Started Dask client={}\n".format(client))
157154

test/test_ensemble_builder/test_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def test_run_end_at(ensemble_backend):
504504

505505
current_time = time.time()
506506

507-
ensbuilder.run(end_at=current_time + 10, iteration=1)
507+
ensbuilder.run(end_at=current_time + 10, iteration=1, pynisher_context='forkserver')
508508
# 4 seconds left because: 10 seconds - 5 seconds overhead - very little overhead,
509509
# but then rounded to an integer
510510
assert pynisher_mock.call_args_list[0][1]["wall_time_in_s"], 4
@@ -579,7 +579,7 @@ def mtime_mock(filename):
579579

580580
# And then it still runs, but basically won't do anything any more except for raising error
581581
# messages via the logger
582-
ensbuilder.run(time_left=1000, iteration=0)
582+
ensbuilder.run(time_left=1000, iteration=0, pynisher_context='fork')
583583
assert os.path.exists(read_scores_file)
584584
assert not os.path.exists(read_preds_file)
585585
assert logger_mock.warning.call_count == 4

test/test_evaluation/test_evaluation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_zero_or_negative_cutoff(self, pynisher_mock):
112112
metric=accuracy,
113113
cost_for_crash=get_cost_of_crash(accuracy),
114114
abort_on_first_run_crash=False,
115+
pynisher_context='forkserver',
115116
)
116117
self.scenario.wallclock_limit = 5
117118
self.stats.submitted_ta_runs += 1
@@ -130,6 +131,7 @@ def test_cutoff_lower_than_remaining_time(self, pynisher_mock):
130131
metric=accuracy,
131132
cost_for_crash=get_cost_of_crash(accuracy),
132133
abort_on_first_run_crash=False,
134+
pynisher_context='forkserver',
133135
)
134136
self.stats.ta_runs = 1
135137
ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None, instance_specific=None,
@@ -224,6 +226,7 @@ def test_eval_with_limits_holdout_fail_timeout(self, pynisher_mock):
224226
metric=accuracy,
225227
cost_for_crash=get_cost_of_crash(accuracy),
226228
abort_on_first_run_crash=False,
229+
pynisher_context='forkserver',
227230
)
228231
info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None,
229232
instance_specific=None, seed=1, capped=False))
@@ -259,6 +262,7 @@ def side_effect(**kwargs):
259262
metric=accuracy,
260263
cost_for_crash=get_cost_of_crash(accuracy),
261264
abort_on_first_run_crash=False,
265+
pynisher_context='forkserver',
262266
)
263267
info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None,
264268
instance_specific=None, seed=1, capped=False))
@@ -282,6 +286,7 @@ def side_effect(**kwargs):
282286
metric=accuracy,
283287
cost_for_crash=get_cost_of_crash(accuracy),
284288
abort_on_first_run_crash=False,
289+
pynisher_context='forkserver',
285290
)
286291
info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None,
287292
instance_specific=None, seed=1, capped=False))

0 commit comments

Comments (0)