
Commit f65fee1

[ML-5537] HorovodRunner public API and basic test (#117)
1 parent 4daa117 commit f65fee1


5 files changed: +154 -12 lines


python/sparkdl/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -20,8 +20,11 @@
 from .transformers.tf_tensor import TFTransformer
 from .transformers.utils import imageInputPlaceholder
 from .estimators.keras_image_file_estimator import KerasImageFileEstimator
+from .horovod.runner_base import HorovodRunnerBase as HorovodRunner
 
 __all__ = [
     'TFImageTransformer', 'TFInputGraph', 'TFTransformer', 'DeepImagePredictor',
     'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
-    'imageInputPlaceholder', 'KerasImageFileEstimator']
+    'imageInputPlaceholder', 'KerasImageFileEstimator',
+    'HorovodRunner'
+]

python/sparkdl/horovod/__init__.py

Whitespace-only changes.
python/sparkdl/horovod/runner_base.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=no-self-use
+# pylint: disable=too-few-public-methods
+
+from __future__ import absolute_import, division, print_function
+
+import logging
+
+from sparkdl.param import keyword_only
+
+class HorovodRunnerBase(object):
+    """
+    HorovodRunner runs distributed deep learning training jobs using Horovod.
+
+    On Databricks Runtime for Machine Learning 5.0+, it launches the Horovod job as a distributed
+    Spark job. It makes running Horovod easy on Databricks by managing the cluster setup and
+    integrating with Spark. Check out the Databricks documentation for end-to-end examples and
+    performance tuning tips.
+
+    The open-source version only runs the job locally inside the same Python process,
+    which is for local development only.
+
+    .. note:: Horovod is a distributed training framework developed by Uber.
+    """
+
+    @keyword_only
+    def __init__(self, np):
+        """
+        :param np: number of parallel processes to use for the Horovod job.
+            This argument only takes effect on Databricks Runtime for Machine Learning 5.0+.
+            It is ignored in the open-source version.
+            Accepted values are:
+
+            - If -1, this will spawn a subprocess on the driver node to run the Horovod job locally.
+              Training stdout and stderr messages go to the notebook cell output.
+              This is useful for debugging and we recommend testing your code under this mode first.
+              However, be careful of heavy use of the Spark driver on a shared Databricks cluster.
+            - If >0, this will launch a Spark job with `np` tasks starting all together and run the
+              Horovod job on the task nodes.
+              It will wait until `np` task slots are available to launch the job.
+              If `np` is greater than the total number of task slots on the cluster,
+              the job will fail.
+              Training stdout and stderr messages are redirected to the stderr stream of the first
+              task, which you can find in the Spark UI.
+            - If 0, this will use all task slots on the cluster to launch the job.
+        """
+        self.num_processor = np
+        if self.num_processor < -1:
+            raise ValueError("Invalid number of processes: np = %s" % str(self.num_processor))
+
+    def run(self, main, **kwargs):
+        """
+        Runs a Horovod training job invoking main(**kwargs).
+
+        The open-source version only invokes main(**kwargs) inside the same Python process.
+        On Databricks Runtime for Machine Learning 5.0+, it will launch the Horovod job based on
+        the documented behavior of `np`. Both the main function and the keyword arguments are
+        serialized using cloudpickle and distributed to cluster workers.
+
+        :param main: a Python function that contains the Horovod training code.
+            The expected signature is `def main(**kwargs)` or compatible forms.
+            Because the function gets pickled and distributed to workers,
+            please make any changes to global state (e.g., setting the logging level)
+            inside the function, and be aware of pickling limitations.
+            Avoid referencing large objects in the function, which might result in large
+            pickled data, making the job slow to start.
+        :param kwargs: keyword arguments passed to the main function at invocation time.
+        :return: None
+        """
+        logger = logging.getLogger("HorovodRunner")
+        logger.warning(
+            "You are running the open-source version of HorovodRunner. "
+            "It only does basic checks and invokes the main function, "
+            "which is for local development only. "
+            "Please use Databricks Runtime ML 5.0+ to distribute the job.")
+        main(**kwargs)
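
To make the API concrete, here is a minimal usage sketch based on the docstrings above. The train function, its parameters, and the printed message are illustrative assumptions, not part of this commit; in the open-source version the call below simply invokes train in the current Python process, while on Databricks Runtime ML 5.0+ np=-1 would run it in a subprocess on the driver.

from sparkdl import HorovodRunner

def train(batch_size, epochs):
    # Stand-in for real Horovod training code (hvd.init(), model setup, fit, ...).
    # Keep large objects out of the enclosing scope so the cloudpickled payload stays small.
    print("training with batch_size=%d and epochs=%d" % (batch_size, epochs))

# np must be passed by keyword because __init__ is wrapped with @keyword_only.
hr = HorovodRunner(np=-1)
hr.run(train, batch_size=32, epochs=2)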

python/sparkdl/param/shared_params.py

Lines changed: 7 additions & 11 deletions
@@ -17,11 +17,11 @@
 to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being
 private APIs.
 """
-from functools import wraps
 import textwrap
 
 import keras.backend as K
 from keras.models import load_model
+import wrapt
 
 from pyspark.ml.param import Param, Params, TypeConverters
 import sparkdl.graph.utils as tfx
@@ -36,22 +36,18 @@
 ########################################################
 
 
-def keyword_only(func):
+@wrapt.decorator
+def keyword_only(func, self, args, kwargs):
     """
     A decorator that forces keyword arguments in the wrapped method
     and saves actual input keyword arguments in `_input_kwargs`.
 
     .. note:: Should only be used to wrap a method where first arg is `self`
     """
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        if len(args) > 0:
-            raise TypeError("Method %s forces keyword arguments." % func.__name__)
-        self._input_kwargs = kwargs
-        return func(self, **kwargs)
-
-    return wrapper
+    if len(args) > 0:
+        raise TypeError("Method %s forces keyword arguments." % func.__name__)
+    self._input_kwargs = kwargs
+    return func(**kwargs)
 
 
 class HasInputCol(Params):
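
To illustrate the rewritten decorator, here is a small sketch; the Toy class and its np parameter are hypothetical, but the behavior shown (positional arguments rejected, keyword arguments captured in _input_kwargs) follows directly from the code in this hunk.

from sparkdl.param import keyword_only

class Toy(object):
    @keyword_only
    def __init__(self, np):
        # By the time the body runs, keyword_only has already stored the raw
        # keyword arguments on the instance as self._input_kwargs.
        self.np = np

toy = Toy(np=4)      # OK: toy._input_kwargs == {'np': 4}
try:
    Toy(4)           # positional arguments are rejected
except TypeError as err:
    print(err)       # Method __init__ forces keyword arguments.

Because wrapt preserves the wrapped method's introspectable signature, inspect.getargspec still reports ["self", "np"] for the decorated __init__, which is exactly what the new signature test below asserts.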
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import, division, print_function
+
+import inspect
+import unittest
+
+from sparkdl.horovod.runner_base import HorovodRunnerBase as HorovodRunner
+
+
+class HorovodRunnerBaseTestCase(unittest.TestCase):
+
+    def test_func_signature(self):
+        """Test that __init__ and run signatures are correct."""
+        init_spec = inspect.getargspec(HorovodRunner.__init__)  # pylint: disable=deprecated-method
+        self.assertEquals(init_spec.args, ["self", "np"])
+        self.assertIsNone(init_spec.varargs)
+        self.assertIsNone(init_spec.keywords)
+        self.assertIsNone(init_spec.defaults)
+        run_spec = inspect.getargspec(HorovodRunner.run)  # pylint: disable=deprecated-method
+        self.assertEquals(run_spec.args, ["self", "main"])
+        self.assertIsNone(run_spec.varargs)
+        self.assertEquals(run_spec.keywords, "kwargs")
+        self.assertIsNone(run_spec.defaults)
+
+    def test_init_keyword_only(self):
+        """Test that user must use keyword args in __init__"""
+        with self.assertRaises(TypeError):
+            HorovodRunner(2)
+
+    def test_run(self):
+        """Test that run just invokes the main method in the same process."""
+        hr = HorovodRunner(np=-1)
+        data = []
+
+        def append(value):
+            data.append(value)
+
+        hr.run(append, value=1)
+        self.assertEquals(data[0], 1)
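
The new test module has no __main__ guard; if one wanted to execute it directly during development, a conventional footer (hypothetical, not part of this commit) would be:

if __name__ == "__main__":
    unittest.main()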
