diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index 05d9374..b86b47b 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -17,6 +17,10 @@
 """
 The joblib spark backend implementation.
 """
+
+# pylint: disable=W0621,W1202
+import sys
+import logging
 import warnings
 from multiprocessing.pool import ThreadPool
 import uuid
@@ -32,6 +36,20 @@
 from pyspark.util import VersionUtils
 
 
+def _get_logger(name):
+    """ Gets a logger by name, or creates and configures it for the first time. """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    # Attach a handler only if neither this logger nor the root logger is already configured.
+    if not logger.handlers and not logging.getLogger().handlers:
+        handler = logging.StreamHandler(sys.stderr)
+        logger.addHandler(handler)
+    return logger
+
+
+logger = _get_logger("joblib-spark")
+
+
 def register():
     """
     Register joblib spark backend.
@@ -56,6 +74,8 @@ class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
     Each task batch will be run inside one spark task on worker node, and will be executed
     by `SequentialBackend`
     """
+    # Hard cap on the number of concurrent joblib task batches (Spark jobs) to run. Set at 128.
+    MAX_CONCURRENT_JOBS_ALLOWED = 128
 
     def __init__(self, **backend_args):
         super(SparkDistributedBackend, self).__init__(**backend_args)
@@ -77,20 +97,67 @@ def _cancel_all_jobs(self):
         else:
             self._spark.sparkContext.cancelJobGroup(self._job_group)
 
+    @staticmethod
+    def _decide_parallelism(requested_parallelism,
+                            spark_default_parallelism,
+                            max_num_concurrent_tasks):
+        """
+        Given the requested parallelism, return the maximum parallelism the backend will
+        actually use. See the docstring of `effective_n_jobs` for the expected behavior.
+        """
+        if max_num_concurrent_tasks == 0:
+            logger.warning(
+                "The cluster has no executors currently. "
+                "The tasks won't start until some new executors register."
+            )
+        if requested_parallelism is None:
+            parallelism = 1
+        elif requested_parallelism <= 0:
+            parallelism = max(spark_default_parallelism, max_num_concurrent_tasks, 1)
+            logger.warning(
+                "Because the requested parallelism was None or a non-positive value, "
+                "parallelism will be set to ({d}), which is Spark's default parallelism ({s}), "
+                "or the current total of Spark task slots ({t}), or 1, whichever is greater. "
+                "We recommend setting parallelism explicitly to a positive value because "
+                "the total of Spark task slots is subject to cluster sizing.".format(
+                    d=parallelism,
+                    s=spark_default_parallelism,
+                    t=max_num_concurrent_tasks,
+                )
+            )
+        else:
+            parallelism = requested_parallelism
+
+        if parallelism > SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED:
+            logger.warning(
+                "Parallelism ({p}) is capped at MAX_CONCURRENT_JOBS_ALLOWED ({c})."
+                .format(p=parallelism, c=SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED)
+            )
+            parallelism = SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED
+
+        if parallelism > max_num_concurrent_tasks:
+            logger.warning(
+                "Parallelism ({p}) is greater than the current total of Spark task slots ({c}). "
+                "If dynamic allocation is enabled, you might see more executors allocated.".format(
+                    p=parallelism, c=max_num_concurrent_tasks
+                )
+            )
+        return parallelism
+
     def effective_n_jobs(self, n_jobs):
+        """
+        When n_jobs is None, request 1 worker.
+        When n_jobs is -1, request all currently available workers;
+        if the cluster uses dynamic allocation and no workers are available yet,
+        fall back to spark_default_parallelism and trigger dynamic allocation of Spark workers.
+        """
         max_num_concurrent_tasks = self._get_max_num_concurrent_tasks()
-        if n_jobs is None:
-            n_jobs = 1
-        elif n_jobs == -1:
-            # n_jobs=-1 means requesting all available workers
-            n_jobs = max_num_concurrent_tasks
-        if n_jobs > max_num_concurrent_tasks:
-            warnings.warn("User-specified n_jobs ({n}) is greater than the max number of "
-                          "concurrent tasks ({c}) this cluster can run now. If dynamic "
-                          "allocation is enabled for the cluster, you might see more "
-                          "executors allocated."
-                          .format(n=n_jobs, c=max_num_concurrent_tasks))
-        return n_jobs
+        spark_default_parallelism = self._spark.sparkContext.defaultParallelism
+        return self._decide_parallelism(
+            requested_parallelism=n_jobs,
+            spark_default_parallelism=spark_default_parallelism,
+            max_num_concurrent_tasks=max_num_concurrent_tasks,
+        )
 
     def _get_max_num_concurrent_tasks(self):
         # maxNumConcurrentTasks() is a package private API
diff --git a/test/test_backend.py b/test/test_backend.py
index d000f06..52e113f 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,9 +1,26 @@
-import warnings
+import contextlib
+import logging
 from unittest.mock import MagicMock
+from six import StringIO
 
 from joblibspark.backend import SparkDistributedBackend
 
 
+@contextlib.contextmanager
+def patch_logger(name, level=logging.INFO):
+    """Patch the named logger and capture its output."""
+    io_out = StringIO()
+    log = logging.getLogger(name)
+    log.setLevel(level)
+    log.handlers = []
+    handler = logging.StreamHandler(io_out)
+    log.addHandler(handler)
+    try:
+        yield io_out
+    finally:
+        log.removeHandler(handler)
+
+
 def test_effective_n_jobs():
     backend = SparkDistributedBackend()
@@ -13,8 +30,49 @@ def test_effective_n_jobs():
     assert backend.effective_n_jobs(n_jobs=None) == 1
     assert backend.effective_n_jobs(n_jobs=-1) == 8
     assert backend.effective_n_jobs(n_jobs=4) == 4
+    assert backend.effective_n_jobs(n_jobs=16) == 16
+
+
+def test_parallelism_arg():
+    for spark_default_parallelism, max_num_concurrent_tasks in [(2, 4), (2, 0)]:
+        default_parallelism = max(spark_default_parallelism, max_num_concurrent_tasks)
+
+        assert 1 == SparkDistributedBackend._decide_parallelism(
+            requested_parallelism=None,
+            spark_default_parallelism=spark_default_parallelism,
+            max_num_concurrent_tasks=max_num_concurrent_tasks,
+        )
+
+        with patch_logger("joblib-spark") as output:
+            parallelism = SparkDistributedBackend._decide_parallelism(
+                requested_parallelism=-1,
+                spark_default_parallelism=spark_default_parallelism,
+                max_num_concurrent_tasks=max_num_concurrent_tasks,
+            )
+            assert parallelism == default_parallelism
+            log_output = output.getvalue().strip()
+            assert "Because the requested parallelism was None or a non-positive value, " \
+                   "parallelism will be set to ({d})".format(d=default_parallelism) in log_output
+
+        # Test a requested_parallelism that will trigger Spark executor dynamic allocation.
+        with patch_logger("joblib-spark") as output:
+            parallelism = SparkDistributedBackend._decide_parallelism(
+                requested_parallelism=max_num_concurrent_tasks + 1,
+                spark_default_parallelism=spark_default_parallelism,
+                max_num_concurrent_tasks=max_num_concurrent_tasks,
+            )
+            assert parallelism == max_num_concurrent_tasks + 1
+            log_output = output.getvalue().strip()
+            assert "Parallelism ({p}) is greater".format(p=max_num_concurrent_tasks + 1) \
+                   in log_output
 
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        assert backend.effective_n_jobs(n_jobs=16) == 16
-        assert len(w) == 1
+        # Test a requested_parallelism that exceeds the hard cap.
+        with patch_logger("joblib-spark") as output:
+            parallelism = SparkDistributedBackend._decide_parallelism(
+                requested_parallelism=SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED + 1,
+                spark_default_parallelism=spark_default_parallelism,
+                max_num_concurrent_tasks=max_num_concurrent_tasks,
+            )
+            assert parallelism == SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED
+            log_output = output.getvalue().strip()
+            assert "MAX_CONCURRENT_JOBS_ALLOWED ({c})" \
+                .format(c=SparkDistributedBackend.MAX_CONCURRENT_JOBS_ALLOWED) in log_output
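
For reference, a minimal usage sketch (not part of the patch) of the behavior this change affects. It assumes joblib and a local pyspark installation, and that `register()` from `joblibspark.backend` registers the backend under the name "spark" as in the project README; with `n_jobs=-1`, `effective_n_jobs` now resolves parallelism through `_decide_parallelism` and emits warnings via the "joblib-spark" logger instead of `warnings.warn`:

from pyspark.sql import SparkSession
from joblib import Parallel, delayed, parallel_backend
from joblibspark.backend import register


def square(x):
    return x * x


if __name__ == "__main__":
    # Start a small local Spark cluster; any SparkSession works.
    spark = SparkSession.builder.master("local[2]").getOrCreate()
    register()  # registers the joblib backend modified in the diff above

    # n_jobs=-1 is resolved by _decide_parallelism: the current number of Spark task
    # slots, or Spark's default parallelism when no executors have registered yet,
    # capped at MAX_CONCURRENT_JOBS_ALLOWED.
    with parallel_backend("spark", n_jobs=-1):
        results = Parallel()(delayed(square)(i) for i in range(10))

    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Because the backend now reports through the logging module rather than `warnings.warn`, the tests switched from `warnings.catch_warnings` to the `patch_logger` helper, which attaches a StreamHandler to the "joblib-spark" logger and inspects its captured output.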