Skip to content

Commit 78f0b8b

Browse files
authored
feature: Add mpi support for mxnet estimator api (#1581)
1 parent 7e77981 commit 78f0b8b

File tree

5 files changed

+152
-43
lines changed

5 files changed

+152
-43
lines changed

src/sagemaker/mxnet/estimator.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,44 @@ def __init__(
9999
* ``custom-image:latest``
100100
101101
distributions (dict): A dictionary with information on how to run distributed
102-
training (default: None). To have parameter servers launched for training,
103-
set this value to be ``{'parameter_server': {'enabled': True}}``.
102+
training (default: None). Currently we support distributed training with
103+
parameter server and MPI [Horovod].
104+
To enable parameter server, use the following setup:
105+
106+
.. code:: python
107+
108+
{
109+
'parameter_server':
110+
{
111+
'enabled': True
112+
}
113+
}
114+
115+
To enable MPI:
116+
117+
.. code:: python
118+
119+
{
120+
'mpi':
121+
{
122+
'enabled': True
123+
}
124+
}
125+
126+
Optional parameters within ``mpi`` are ``processes_per_host``
127+
and ``custom_mpi_options``.
128+
129+
.. code:: python
130+
131+
{
132+
'mpi':
133+
{
134+
'enabled': True,
135+
'processes_per_host': 2,
136+
'custom_mpi_options': '-verbose --NCCL_DEBUG=INFO'
137+
}
138+
}
139+
104140
**kwargs: Additional kwargs passed to the
105141
:class:`~sagemaker.estimator.Framework` constructor.
106142
@@ -159,6 +195,20 @@ def _configure_distribution(self, distributions):
159195
enabled = distributions["parameter_server"].get("enabled", False)
160196
self._hyperparameters[self.LAUNCH_PS_ENV_NAME] = enabled
161197

198+
if "mpi" in distributions:
199+
mpi_dict = distributions["mpi"]
200+
mpi_enabled = mpi_dict.get("enabled", False)
201+
self._hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
202+
203+
if mpi_dict.get("processes_per_host"):
204+
self._hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
205+
"processes_per_host"
206+
)
207+
208+
self._hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
209+
"custom_mpi_options", ""
210+
)
211+
162212
def create_model(
163213
self,
164214
model_server_workers=None,

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ def cpu_instance_type(sagemaker_session, request):
287287
return "ml.m4.xlarge"
288288

289289

290+
@pytest.fixture(scope="module")
291+
def gpu_instance_type(request):
292+
return "ml.p2.xlarge"
293+
294+
290295
@pytest.fixture(scope="session")
291296
def inf_instance_type(sagemaker_session, request):
292297
return "ml.inf1.xlarge"

tests/data/horovod/hvd_mnist_mxnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import json
2+
import os
3+
4+
import horovod.mxnet as hvd
5+
6+
if __name__ == "__main__":
7+
8+
hvd.init()
9+
10+
with open(os.path.join("/opt/ml/model/rank-%s" % (hvd.rank())), "w+") as f:
11+
basic_info = {"rank": hvd.rank(), "size": hvd.size()}
12+
13+
json.dump(basic_info, f)
14+
print('Saved file "rank-%s": %s' % (hvd.rank(), basic_info))

tests/integ/test_horovod.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
import json
1616
import os
1717
import tarfile
18-
from six.moves.urllib.parse import urlparse
1918

2019
import boto3
2120
import pytest
21+
from six.moves.urllib.parse import urlparse
2222

2323
import sagemaker.utils
2424
import tests.integ as integ
@@ -28,11 +28,6 @@
2828
horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod")
2929

3030

31-
@pytest.fixture(scope="module")
32-
def gpu_instance_type(request):
33-
return "ml.p2.xlarge"
34-
35-
3631
@pytest.mark.canary_quick
3732
def test_hvd_cpu(sagemaker_session, cpu_instance_type, tmpdir):
3833
_create_and_fit_estimator(sagemaker_session, cpu_instance_type, tmpdir)
@@ -46,41 +41,6 @@ def test_hvd_gpu(sagemaker_session, gpu_instance_type, tmpdir):
4641
_create_and_fit_estimator(sagemaker_session, gpu_instance_type, tmpdir)
4742

4843

49-
@pytest.mark.local_mode
50-
@pytest.mark.parametrize("instances, processes", [[1, 2], (2, 1), (2, 2)])
51-
def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdir):
52-
output_path = "file://%s" % tmpdir
53-
job_name = sagemaker.utils.unique_name_from_base("tf-horovod")
54-
estimator = TensorFlow(
55-
entry_point=os.path.join(horovod_dir, "hvd_basic.py"),
56-
role="SageMakerRole",
57-
train_instance_count=2,
58-
train_instance_type="local",
59-
sagemaker_session=sagemaker_local_session,
60-
py_version=integ.PYTHON_VERSION,
61-
script_mode=True,
62-
output_path=output_path,
63-
framework_version="1.12",
64-
distributions={"mpi": {"enabled": True, "processes_per_host": processes}},
65-
)
66-
67-
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
68-
estimator.fit(job_name=job_name)
69-
70-
tmp = str(tmpdir)
71-
extract_files(output_path.replace("file://", ""), tmp)
72-
73-
size = instances * processes
74-
75-
for rank in range(size):
76-
assert read_json("rank-%s" % rank, tmp)["rank"] == rank
77-
78-
79-
def extract_files(output_path, tmpdir):
80-
with tarfile.open(os.path.join(output_path, "model.tar.gz")) as tar:
81-
tar.extractall(tmpdir)
82-
83-
8444
def read_json(file, tmp):
8545
with open(os.path.join(tmp, file)) as f:
8646
return json.load(f)

tests/integ/test_horovod_mx.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
import tarfile
18+
19+
import boto3
20+
import pytest
21+
from six.moves.urllib.parse import urlparse
22+
23+
import sagemaker.utils
24+
import tests.integ as integ
25+
from sagemaker.mxnet import MXNet
26+
from tests.integ import timeout
27+
28+
horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod")
29+
30+
31+
@pytest.mark.canary_quick
32+
def test_hvd_cpu(sagemaker_session, cpu_instance_type, tmpdir):
33+
_create_and_fit_estimator(sagemaker_session, cpu_instance_type, tmpdir)
34+
35+
36+
@pytest.mark.canary_quick
37+
@pytest.mark.skipif(
38+
integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region"
39+
)
40+
def test_hvd_gpu(sagemaker_session, gpu_instance_type, tmpdir):
41+
_create_and_fit_estimator(sagemaker_session, gpu_instance_type, tmpdir)
42+
43+
44+
def read_json(file, tmp):
45+
with open(os.path.join(tmp, file)) as f:
46+
return json.load(f)
47+
48+
49+
def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
50+
parsed_url = urlparse(s3_url)
51+
s3 = boto3.resource("s3", region_name=sagemaker_session.boto_region_name)
52+
53+
model = os.path.join(tmpdir, "model")
54+
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)
55+
56+
with tarfile.open(model, "r") as tar_file:
57+
tar_file.extractall(tmpdir)
58+
59+
60+
def _create_and_fit_estimator(sagemaker_session, instance_type, tmpdir):
61+
job_name = sagemaker.utils.unique_name_from_base("mx-horovod")
62+
estimator = MXNet(
63+
entry_point=os.path.join(horovod_dir, "hvd_mnist_mxnet.py"),
64+
role="SageMakerRole",
65+
train_instance_count=2,
66+
train_instance_type=instance_type,
67+
sagemaker_session=sagemaker_session,
68+
py_version=integ.PYTHON_VERSION,
69+
framework_version="1.6.0",
70+
distributions={"mpi": {"enabled": True}},
71+
)
72+
73+
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
74+
estimator.fit(job_name=job_name)
75+
76+
tmp = str(tmpdir)
77+
extract_files_from_s3(estimator.model_data, tmp, sagemaker_session)
78+
79+
for rank in range(2):
80+
assert read_json("rank-%s" % rank, tmp)["rank"] == rank

0 commit comments

Comments
 (0)