Skip to content

Commit dfa3c07

Browse files
MrBagomengxr
authored andcommitted
[ML-6097] Update Horovod Runner to include return value and multi-gpu info. (#183)
* [ML-6012] add return value in oss, update document and add test (#138) In this PR, we: change the horovodrunner.run api to return values in horovod/runner_base.py update the documentation to reflect that change add a test to runner_base_test.py to test the api change * Update HorovodRunner docs to include multi-gpu support. * Revert "Update HorovodRunner docs to include multi-gpu support." This reverts commit 0172bbe. * update np doc (#136)
1 parent 7aff76e commit dfa3c07

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

python/sparkdl/horovod/runner_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __init__(self, np):
4343
:param np: number of parallel processes to use for the Horovod job.
4444
This argument only takes effect on Databricks Runtime 5.0 ML and above.
4545
It is ignored in the open-source version.
46+
On Databricks, each process will take an available task slot,
47+
which maps to a GPU on a GPU cluster or a CPU core on a CPU cluster.
4648
Accepted values are:
4749
4850
- If -1, this will spawn a subprocess on the driver node to run the Horovod job locally.
@@ -79,12 +81,14 @@ def run(self, main, **kwargs):
7981
Avoid referencing large objects in the function, which might result large pickled data,
8082
making the job slow to start.
8183
:param kwargs: keyword arguments passed to the main function at invocation time.
82-
:return: None
84+
:return: return value of the main function.
85+
With `np>=0`, this returns the value from the rank 0 process. Note that the returned
86+
value should be serializable using cloudpickle.
8387
"""
8488
logger = logging.getLogger("HorovodRunner")
8589
logger.warning(
8690
"You are running the open-source version of HorovodRunner. "
8791
"It only does basic checks and invokes the main function, "
8892
"which is for local development only. "
8993
"Please use Databricks Runtime ML 5.0+ to distribute the job.")
90-
main(**kwargs)
94+
return main(**kwargs)

python/tests/horovod/runner_base_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,10 @@ def append(value):
5151

5252
hr.run(append, value=1)
5353
self.assertEquals(data[0], 1)
54+
55+
def test_return_value(self):
56+
"""Test that the return value is returned to the user."""
57+
hr = HorovodRunner(np=-1)
58+
return_value = hr.run(lambda: 42)
59+
self.assertEquals(return_value, 42)
60+

0 commit comments

Comments
 (0)