
Commit bbb3f16

akhilmehra authored and ChoiByungWook committed
feature: add model parallelism support (#441)
1 parent c8104af commit bbb3f16

File tree

5 files changed (+233, -8 lines)


src/sagemaker/fw_utils.py

Lines changed: 90 additions & 0 deletions
@@ -76,6 +76,96 @@ def validate_source_dir(script, directory):
     return True


+def get_mp_parameters(distribution):
+    """Get the model parallelism parameters provided by the user
+
+    Args:
+        distribution: distribution dictionary defined by the user
+
+    Returns:
+        params: dictionary containing model parallelism parameters
+        to be used for training
+    """
+    try:
+        mp_dict = distribution["smdistributed"]["modelparallel"]
+    except KeyError:
+        mp_dict = {}
+    if mp_dict.get("enabled", False) is True:
+        params = mp_dict.get("parameters", {})
+        validate_mp_config(params)
+        return params
+    return None
+
+
+def validate_mp_config(config):
+    """Validate the configuration dictionary for model parallelism.
+
+    Args:
+        config (dict): Dictionary holding configuration keys and values.
+
+    Raises:
+        ValueError: If any of the keys have incorrect values.
+    """
+
+    if "partitions" not in config:
+        raise ValueError("'partitions' is a required parameter.")
+
+    def validate_positive(key):
+        try:
+            if not isinstance(config[key], int) or config[key] < 1:
+                raise ValueError(f"The number of {key} must be a positive integer.")
+        except KeyError:
+            pass
+
+    def validate_in(key, vals):
+        try:
+            if config[key] not in vals:
+                raise ValueError(f"{key} must be a value in: {vals}.")
+        except KeyError:
+            pass
+
+    def validate_bool(keys):
+        validate_in(keys, [True, False])
+
+    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
+    validate_in("placement_strategy", ["spread", "cluster"])
+    validate_in("optimize", ["speed", "memory"])
+
+    for key in ["microbatches", "partitions"]:
+        validate_positive(key)
+
+    for key in ["auto_partition", "contiguous", "load_partition", "horovod", "ddp"]:
+        validate_bool(key)
+
+    if "partition_file" in config and not isinstance(config.get("partition_file"), str):
+        raise ValueError("'partition_file' must be a str.")
+
+    if config.get("auto_partition") is False and "default_partition" not in config:
+        raise ValueError("default_partition must be supplied if auto_partition is set to False!")
+
+    if "default_partition" in config and config["default_partition"] >= config["partitions"]:
+        raise ValueError("default_partition must be less than the number of partitions!")
+
+    if "memory_weight" in config and (
+        config["memory_weight"] > 1.0 or config["memory_weight"] < 0.0
+    ):
+        raise ValueError("memory_weight must be between 0.0 and 1.0!")
+
+    if "ddp_port" in config and "ddp" not in config:
+        raise ValueError("`ddp_port` needs `ddp` to be set as well")
+
+    if "ddp_dist_backend" in config and "ddp" not in config:
+        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")
+
+    if "ddp_port" in config:
+        if not isinstance(config["ddp_port"], int) or config["ddp_port"] < 0:
+            value = config["ddp_port"]
+            raise ValueError(f"Invalid port number {value}.")
+
+    if config.get("horovod", False) and config.get("ddp", False):
+        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")
+
+
 def tar_and_upload_dir(
     session,
     bucket,

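To see how the two new helpers behave together, here is a minimal usage sketch (not part of the commit); the distribution dict and its parameter values below are illustrative examples, not SDK defaults.

from sagemaker.fw_utils import get_mp_parameters, validate_mp_config

# Illustrative user-supplied distribution dict.
distribution = {
    "mpi": {"enabled": True},
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {"partitions": 2, "microbatches": 4, "pipeline": "interleaved"},
        }
    },
}

# Returns the validated parameters dict when model parallelism is enabled...
print(get_mp_parameters(distribution))
# {'partitions': 2, 'microbatches': 4, 'pipeline': 'interleaved'}

# ...and None when it is disabled or absent.
assert get_mp_parameters({"mpi": {"enabled": True}}) is None

# validate_mp_config can also be called directly; omitting 'partitions' raises ValueError.
try:
    validate_mp_config({"microbatches": 4})
except ValueError as err:
    print(err)  # 'partitions' is a required parameter.
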
src/sagemaker/pytorch/estimator.py

Lines changed: 57 additions & 4 deletions
@@ -24,7 +24,9 @@
     framework_version_from_tag,
     python_deprecation_warning,
     validate_version_or_image_args,
+    warn_if_parameter_server_with_multi_gpu,
     validate_smdistributed,
+    get_mp_parameters,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
@@ -93,7 +95,6 @@ def __init__(
                 for training and hosting, instead of selecting the appropriate
                 SageMaker official image based on framework_version and
                 py_version. It can be an ECR url or dockerhub image and tag.
-
                 Examples:
                     * ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
                     * ``custom-image:latest``
@@ -102,17 +103,41 @@ def __init__(
                 ``image_uri`` is required. If also ``None``, then a ``ValueError``
                 will be raised.
             distribution (dict): A dictionary with information on how to run distributed training
-                (default: None). Currently we support distributed training with SMDistributed
-                Data Parallel strategy.
+                (default: None). Currently we support distributed training with parameter servers,
+                Model Parallelism, Data Parallelism, and MPI. Model Parallelism can only be used
+                with MPI.
+                To enable parameter server use the following setup:
+
+                .. code:: python
+
+                    {
+                        "parameter_server": {
+                            "enabled": True
+                        }
+                    }
+
+                To enable MPI:
+
+                .. code:: python
+
+                    {
+                        "mpi": {
+                            "enabled": True
+                        }
+                    }

-                To enable SMDistributed Data Parallel:
+                To enable SMDistributed Data Parallel or Model Parallel:

                 .. code:: python

                     {
                         "smdistributed": {
                             "dataparallel": {
                                 "enabled": True
+                            },
+                            "modelparallel": {
+                                "enabled": True,
+                                "parameters": {}
                             }
                         }
                     }
@@ -148,6 +173,10 @@ def __init__(
             image_uri=image_uri,
         )

+        warn_if_parameter_server_with_multi_gpu(
+            training_instance_type=instance_type, distribution=distribution
+        )
+
         if "enable_sagemaker_metrics" not in kwargs:
             # enable sagemaker metrics for PT v1.3 or greater:
             if self.framework_version and Version(self.framework_version) >= Version("1.3"):
@@ -163,6 +192,30 @@ def hyperparameters(self):
         hyperparameters = super(PyTorch, self).hyperparameters()
         additional_hyperparameters = {}

+        if "parameter_server" in self.distribution:
+            ps_enabled = self.distribution.get("parameter_server").get("enabled", False)
+            additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
+
+        if "mpi" in self.distribution:
+            mpi_dict = self.distribution["mpi"]
+            mpi_enabled = mpi_dict.get("enabled", False)
+            additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
+
+            if mpi_dict.get("processes_per_host"):
+                additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
+                    "processes_per_host"
+                )
+
+            additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
+                "custom_mpi_options", ""
+            )
+
+            if get_mp_parameters(self.distribution):
+                additional_hyperparameters["mp_parameters"] = get_mp_parameters(self.distribution)
+
+        elif "modelparallel" in self.distribution.get("smdistributed", {}):
+            raise ValueError("Cannot use Model Parallelism without MPI enabled!")
+
         if "smdistributed" in self.distribution:
             # smdistributed strategy selected
             smdistributed = self.distribution["smdistributed"]

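Putting the docstring and hyperparameters() changes together, a hedged sketch of how a user might enable model parallelism on the PyTorch estimator; the entry point, role, framework version, and instance settings are placeholder values, not taken from this commit.

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",          # placeholder training script
    role="SageMakerRole",            # placeholder IAM role
    framework_version="1.6.0",       # placeholder; pick a version with smdistributed support
    py_version="py3",
    instance_count=2,
    instance_type="ml.p3.16xlarge",  # placeholder GPU instance type
    distribution={
        # Model parallelism requires MPI, per the new check in hyperparameters().
        "mpi": {"enabled": True, "processes_per_host": 8},
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {"partitions": 2, "microbatches": 4},
            }
        },
    },
)
# estimator.fit("s3://my-bucket/training-data")  # placeholder S3 input

When fit() runs, hyperparameters() folds the validated model parallelism settings into the "mp_parameters" hyperparameter alongside the MPI settings.
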
src/sagemaker/tensorflow/estimator.py

Lines changed: 16 additions & 3 deletions
@@ -81,8 +81,9 @@ def __init__(
                 ``image_uri`` is required. If also ``None``, then a ``ValueError``
                 will be raised.
             distribution (dict): A dictionary with information on how to run distributed training
-                (default: None). Currently we support distributed training with parameter servers
-                and MPI.
+                (default: None). Currently we support distributed training with parameter servers,
+                Model Parallelism, Data Parallelism, and MPI. Model Parallelism can only be used
+                with MPI.
                 To enable parameter server use the following setup:

                 .. code:: python
@@ -103,14 +104,18 @@ def __init__(
                     }
                 }

-                To enable SMDistributed Data Parallel:
+                To enable SMDistributed Data Parallel or Model Parallel:

                 .. code:: python

                     {
                         "smdistributed": {
                             "dataparallel": {
                                 "enabled": True
+                            },
+                            "modelparallel": {
+                                "enabled": True,
+                                "parameters": {}
                             }
                         }
                     }
@@ -335,6 +340,14 @@ def hyperparameters(self):
                 "custom_mpi_options", ""
             )

+            if fw.get_mp_parameters(self.distribution):
+                additional_hyperparameters["mp_parameters"] = fw.get_mp_parameters(
+                    self.distribution
+                )
+
+        elif "modelparallel" in self.distribution.get("smdistributed", {}):
+            raise ValueError("Cannot use Model Parallelism without MPI enabled!")
+
         if "smdistributed" in self.distribution:
             # smdistributed strategy selected
             smdistributed = self.distribution["smdistributed"]

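The TensorFlow estimator accepts the same distribution shape; a brief sketch (values are illustrative) of the valid and invalid combinations the new check distinguishes:

# Valid: model parallelism together with MPI.
distribution = {
    "mpi": {"enabled": True},
    "smdistributed": {"modelparallel": {"enabled": True, "parameters": {"partitions": 2}}},
}

# Invalid: model parallelism without MPI. When hyperparameters() runs (for example during
# fit()), the estimator raises ValueError("Cannot use Model Parallelism without MPI enabled!").
bad_distribution = {
    "smdistributed": {"modelparallel": {"enabled": True, "parameters": {"partitions": 2}}},
}
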
tests/integ/test_pytorch.py

Lines changed: 29 additions & 1 deletion
@@ -36,6 +36,28 @@
 EIA_SCRIPT = os.path.join(EIA_DIR, "empty_inference_script.py")


+@pytest.fixture(scope="module", name="pytorch_mpi_training_job")
+def fixture_mpi_training_job(
+    sagemaker_session,
+    pytorch_training_latest_version,
+    pytorch_training_latest_py_version,
+    cpu_instance_type,
+):
+
+    distribution_dict = {"mpi": {"enabled": True}}
+    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
+        pytorch = _get_pytorch_estimator(
+            sagemaker_session,
+            pytorch_training_latest_version,
+            pytorch_training_latest_py_version,
+            cpu_instance_type,
+            distributions_dict=distribution_dict,
+        )
+
+        pytorch.fit({"training": _upload_training_data(pytorch)})
+        return pytorch.latest_training_job.name
+
+
 @pytest.fixture(scope="module", name="pytorch_training_job")
 def fixture_training_job(
     sagemaker_session,
@@ -220,7 +242,12 @@ def _upload_training_data(pytorch):


 def _get_pytorch_estimator(
-    sagemaker_session, pytorch_version, py_version, instance_type, entry_point=MNIST_SCRIPT
+    sagemaker_session,
+    pytorch_version,
+    py_version,
+    instance_type,
+    entry_point=MNIST_SCRIPT,
+    distributions_dict={},
 ):
     return PyTorch(
         entry_point=entry_point,
@@ -230,6 +257,7 @@ def _get_pytorch_estimator(
         instance_count=1,
         instance_type=instance_type,
         sagemaker_session=sagemaker_session,
+        distributions=distributions_dict,
     )


tests/unit/test_fw_utils.py

Lines changed: 41 additions & 0 deletions
@@ -19,6 +19,7 @@
 from itertools import product

 import pytest
+
 from mock import Mock, patch

 from sagemaker import fw_utils
@@ -92,6 +93,46 @@ def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
     obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)


+def test_mp_config_partition_exists():
+    mp_parameters = {}
+    with pytest.raises(ValueError):
+        fw_utils.validate_mp_config(mp_parameters)
+
+
+@pytest.mark.parametrize(
+    "pipeline, placement_strategy, optimize, trace_device",
+    [
+        ("simple", "spread", "speed", "cpu"),
+        ("interleaved", "cluster", "memory", "gpu"),
+        ("_only_forward", "spread", "speed", "gpu"),
+    ],
+)
+def test_mp_config_string_names(pipeline, placement_strategy, optimize, trace_device):
+    mp_parameters = {
+        "partitions": 2,
+        "pipeline": pipeline,
+        "placement_strategy": placement_strategy,
+        "optimize": optimize,
+        "trace_device": trace_device,
+    }
+    fw_utils.validate_mp_config(mp_parameters)
+
+
+def test_mp_config_auto_partition_arg():
+    mp_parameters = {}
+    mp_parameters["partitions"] = 2
+    mp_parameters["auto_partition"] = False
+    with pytest.raises(ValueError):
+        fw_utils.validate_mp_config(mp_parameters)
+
+    mp_parameters["default_partition"] = 1
+    fw_utils.validate_mp_config(mp_parameters)
+
+    mp_parameters["default_partition"] = 4
+    with pytest.raises(ValueError):
+        fw_utils.validate_mp_config(mp_parameters)
+
+
 def test_validate_source_dir_does_not_exits(sagemaker_session):
     script = "mnist.py"
     directory = " !@#$%^&*()path probably in not there.!@#$%^&*()"

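Beyond the cases covered above, a few more illustrative inputs (not part of this test file) that exercise the remaining validation branches in validate_mp_config:

import pytest

from sagemaker import fw_utils

# ddp_port requires ddp to be set.
with pytest.raises(ValueError):
    fw_utils.validate_mp_config({"partitions": 2, "ddp_port": 1111})

# horovod and ddp cannot both be enabled.
with pytest.raises(ValueError):
    fw_utils.validate_mp_config({"partitions": 2, "horovod": True, "ddp": True})

# memory_weight must lie in [0.0, 1.0].
with pytest.raises(ValueError):
    fw_utils.validate_mp_config({"partitions": 2, "memory_weight": 1.5})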