# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods

from __future__ import absolute_import, division, print_function

import logging

from sparkdl.param import keyword_only


class HorovodRunnerBase(object):
| 26 | + """ |
| 27 | + HorovodRunner runs distributed deep learning training jobs using Horovod. |
| 28 | +
|
| 29 | + On Databricks Runtime for Machine Learning 5.0+, it launches the Horovod job as a distributed |
| 30 | + Spark job. It makes running Horovod easy on Databricks by managing the cluster setup and |
| 31 | + integrating with Spark. Check out Databricks documentation to view end-to-end examples and |
| 32 | + performance tuning tips. |
| 33 | +
|
| 34 | + The open-source version only runs the job locally inside the same Python process, |
| 35 | + which is for local development only. |
| 36 | +
|
| 37 | + .. note:: Horovod is a distributed training framework developed by Uber. |
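
    A minimal usage sketch (`train` is an illustrative placeholder for user-supplied Horovod
    training code, not part of this API)::

        def train():
            import horovod.tensorflow as hvd
            hvd.init()
            # ... set up and run Horovod training on this worker ...

        hr = HorovodRunnerBase(np=2)
        hr.run(train)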
| 38 | + """ |
| 39 | + |
| 40 | + @keyword_only |
| 41 | + def __init__(self, np): |
| 42 | + """ |
| 43 | + :param np: number of parallel processes to use for the Horovod job. |
| 44 | + This argument only takes effect on Databricks Runtime for Machine Learning 5.0+. |
| 45 | + It is ignored in the open-source version. |
| 46 | + Accepted values are: |
| 47 | +
|
| 48 | + - If -1, this will spawn a subprocess on the driver node to run the Horovod job locally. |
| 49 | + Training stdout and stderr messages go to the notebook cell output. |
| 50 | + This is useful for debugging and we recommend testing your code under this mode first. |
| 51 | + However, be careful of heavy use of the Spark driver on a shared Databricks cluster. |
| 52 | + - If >0, this will launch a Spark job with `np` tasks starting all together and run the |
| 53 | + Horovod job on the task nodes. |
| 54 | + It will wait until `np` task slots are available to launch the job. |
| 55 | + If `np` is greater than the total number of task slots on the cluster, |
| 56 | + the job will fail. |
| 57 | + Training stdout and stderr messages are redirected to the stderr stream of the first |
| 58 | + task, which you can find in the Spark UI. |
| 59 | + - If 0, this will use all task slots on the cluster to launch the job. |
| 60 | + """ |
        self.num_processor = np
        if self.num_processor < -1:
            raise ValueError("Invalid number of processes: np = %s" % str(self.num_processor))

    def run(self, main, **kwargs):
| 66 | + """ |
| 67 | + Runs a Horovod training job invoking main(**kwargs). |
| 68 | +
|
| 69 | + The open-source version only invokes main(**kwargs) inside the same Python process. |
| 70 | + On Databricks Runtime for Machine Learning 5.0+, it will launch the Horovod job based on the |
| 71 | + documented behavior of `np`. Both the main function and the keyword arguments are |
| 72 | + serialized using cloudpickle and distributed to cluster workers. |
| 73 | +
|
| 74 | + :param main: a Python function that contains the Horovod training code. |
| 75 | + The expected signature is `def main(**kwargs)` or compatible forms. |
| 76 | + Because the function gets pickled and distributed to workers, |
| 77 | + please change global states inside the function, e.g., setting logging level, |
| 78 | + and be aware of pickling limitations. |
| 79 | + Avoid referencing large objects in the function, which might result large pickled data, |
| 80 | + making the job slow to start. |
| 81 | + :param kwargs: keyword arguments passed to the main function at invocation time. |
| 82 | + :return: None |
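
        A minimal sketch of how keyword arguments are forwarded (`train` and its `batch_size`
        and `epochs` parameters are illustrative, not part of this API)::

            def train(batch_size, epochs):
                import horovod.tensorflow as hvd
                hvd.init()
                # ... training loop that uses batch_size and epochs ...

            HorovodRunnerBase(np=-1).run(train, batch_size=64, epochs=4)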
| 83 | + """ |
| 84 | + logger = logging.getLogger("HorovodRunner") |
| 85 | + logger.warning( |
| 86 | + "You are running the open-source version of HorovodRunner. " |
| 87 | + "It only does basic checks and invokes the main function, " |
| 88 | + "which is for local development only. " |
| 89 | + "Please use Databricks Runtime ML 5.0+ to distribute the job.") |
        main(**kwargs)