
Commit 4d0c2e6

change: Warn if parameter server is used with multi-GPU instance (#1376)
Distributed training with a parameter server on multi-GPU instances is not supported. Warn the user that training will not fully leverage all GPU cores if a parameter server is enabled and a multi-GPU instance is selected.
1 parent d8b3012 commit 4d0c2e6
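
Concretely, the configuration this commit warns about is a parameter-server distribution combined with a multi-GPU training instance. A minimal sketch of that combination (the instance type below is just one example of a multi-GPU type, not a value taken from the commit):

# Sketch only: per the new check, any "p"-family instance other than
# ml.p2.xlarge / ml.p3.2xlarge counts as multi-GPU.
distributions = {"parameter_server": {"enabled": True}}
train_instance_type = "ml.p3.8xlarge"  # 4 GPUs, but only one parameter server worker runs per host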

File tree

4 files changed: +72 −0 lines changed

src/sagemaker/fw_utils.py
src/sagemaker/mxnet/estimator.py
src/sagemaker/tensorflow/estimator.py
tests/unit/test_fw_utils.py


src/sagemaker/fw_utils.py
Lines changed: 49 additions & 0 deletions

@@ -13,6 +13,7 @@
 """Utility methods used by framework classes"""
 from __future__ import absolute_import

+import logging
 import os
 import re
 import shutil
@@ -23,6 +24,8 @@
 from sagemaker import s3
 from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN

+logger = logging.getLogger("sagemaker")
+
 _TAR_SOURCE_FILENAME = "source.tar.gz"

 UploadedCode = namedtuple("UserCode", ["s3_prefix", "script_name"])
@@ -42,6 +45,13 @@
     "Python 2. Newer versions of {framework} will only be available for Python 3."
     "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
 )
+PARAMETER_SERVER_MULTI_GPU_WARNING = (
+    "You have selected a multi-GPU training instance type. "
+    "You have also enabled parameter server for distributed training. "
+    "Distributed training with the default parameter server configuration will not "
+    "fully leverage all GPU cores; the parameter server will be configured to run "
+    "only one worker per host regardless of the number of GPUs."
+)


 EMPTY_FRAMEWORK_VERSION_ERROR = (
@@ -68,6 +78,7 @@
 DEFAULT_ACCOUNT = "520713654638"
 ASIMOV_PROD_ACCOUNT = "763104351884"
 ASIMOV_DEFAULT_ACCOUNT = ASIMOV_PROD_ACCOUNT
+SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")

 MERGED_FRAMEWORKS_REPO_MAP = {
     "tensorflow-scriptmode": "tensorflow-training",
@@ -490,6 +501,44 @@ def empty_framework_version_warning(default_version, latest_version):
     return " ".join(msgs)


+def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):
+    """Warn the user that training will not fully leverage all the GPU
+    cores if parameter server is enabled and a multi-GPU instance is selected.
+    Distributed training with the default parameter server setup doesn't
+    support multi-GPU instances.
+
+    Args:
+        training_instance_type (str): A string representing the type of training instance selected.
+        distributions (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    'parameter_server':
+                    {
+                        'enabled': True
+                    }
+                }
+
+
+    """
+    if training_instance_type == "local" or distributions is None:
+        return
+
+    is_multi_gpu_instance = (
+        training_instance_type.split(".")[1].startswith("p")
+        and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
+    )
+
+    ps_enabled = "parameter_server" in distributions and distributions["parameter_server"].get(
+        "enabled", False
+    )
+
+    if is_multi_gpu_instance and ps_enabled:
+        logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
+
+
 def get_unsupported_framework_version_error(
     framework_name, unsupported_version, supported_versions
 ):
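
To make the detection heuristic concrete, here is a hedged sketch of calling the new helper directly; the logging setup is illustrative and not part of the commit:

import logging

from sagemaker import fw_utils

logging.basicConfig(level=logging.WARNING)

ps = {"parameter_server": {"enabled": True}}

# Family "p2" starts with "p" and ml.p2.8xlarge is not in SINGLE_GPU_INSTANCE_TYPES,
# so this call logs PARAMETER_SERVER_MULTI_GPU_WARNING.
fw_utils.warn_if_parameter_server_with_multi_gpu("ml.p2.8xlarge", ps)

# ml.p3.2xlarge is listed in SINGLE_GPU_INSTANCE_TYPES and ml.c5.xlarge is not a
# "p"-family instance, so neither call warns.
fw_utils.warn_if_parameter_server_with_multi_gpu("ml.p3.2xlarge", ps)
fw_utils.warn_if_parameter_server_with_multi_gpu("ml.c5.xlarge", ps)

# Local mode and distributions=None return early before any instance-type parsing.
fw_utils.warn_if_parameter_server_with_multi_gpu("local", ps)
fw_utils.warn_if_parameter_server_with_multi_gpu("ml.p2.8xlarge", None)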

src/sagemaker/mxnet/estimator.py
Lines changed: 7 additions & 0 deletions

@@ -22,6 +22,7 @@
     empty_framework_version_warning,
     python_deprecation_warning,
     is_version_equal_or_higher,
+    warn_if_parameter_server_with_multi_gpu,
 )
 from sagemaker.mxnet import defaults
 from sagemaker.mxnet.model import MXNetModel
@@ -126,6 +127,12 @@ def __init__(
                 python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
             )

+        if distributions is not None:
+            train_instance_type = kwargs.get("train_instance_type")
+            warn_if_parameter_server_with_multi_gpu(
+                training_instance_type=train_instance_type, distributions=distributions
+            )
+
         self.py_version = py_version
         self._configure_distribution(distributions)
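
For illustration, a hypothetical MXNet estimator configuration that would now log the warning during construction; the entry point, role ARN, and framework version are placeholders, not values from the commit:

from sagemaker.mxnet import MXNet

# Hypothetical values throughout; only the multi-GPU instance type plus the
# enabled parameter server matter for triggering the warning.
estimator = MXNet(
    entry_point="train.py",                                # placeholder script
    role="arn:aws:iam::111122223333:role/SageMakerRole",   # placeholder role
    train_instance_count=2,
    train_instance_type="ml.p3.8xlarge",                   # multi-GPU instance
    framework_version="1.6.0",
    py_version="py3",
    distributions={"parameter_server": {"enabled": True}},
)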

src/sagemaker/tensorflow/estimator.py
Lines changed: 6 additions & 0 deletions

@@ -307,6 +307,12 @@ def __init__(
                 fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
             )

+        if distributions is not None:
+            train_instance_type = kwargs.get("train_instance_type")
+            fw.warn_if_parameter_server_with_multi_gpu(
+                training_instance_type=train_instance_type, distributions=distributions
+            )
+
         if "enable_sagemaker_metrics" not in kwargs:
             # enable sagemaker metrics for TF v1.15 or greater:
             if fw.is_version_equal_or_higher([1, 15], self.framework_version):
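
Since the message is emitted through the shared "sagemaker" logger introduced in fw_utils.py, users who have deliberately accepted the one-worker-per-host behavior can silence it before constructing an estimator; a minimal sketch using only the standard library:

import logging

# Raising the "sagemaker" logger above WARNING suppresses this warning
# (along with other SDK warnings) without affecting error reporting.
logging.getLogger("sagemaker").setLevel(logging.ERROR)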

tests/unit/test_fw_utils.py
Lines changed: 10 additions & 0 deletions

@@ -1170,3 +1170,13 @@ def test_region_supports_debugger_feature_returns_true_for_supported_regions():
 def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
     assert fw_utils._region_supports_debugger("us-gov-west-1") is False
     assert fw_utils._region_supports_debugger("us-iso-east-1") is False
+
+
+def test_warn_if_parameter_server_with_multi_gpu(caplog):
+    train_instance_type = "ml.p2.8xlarge"
+    distributions = {"parameter_server": {"enabled": True}}
+
+    fw_utils.warn_if_parameter_server_with_multi_gpu(
+        training_instance_type=train_instance_type, distributions=distributions
+    )
+    assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
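
The commit only tests the positive case; a hypothetical companion test in the same style (not part of the diff) could assert the warning is absent for a single-GPU instance:

def test_warn_if_parameter_server_with_multi_gpu_single_gpu_instance(caplog):
    # Hypothetical negative case: ml.p3.2xlarge is in SINGLE_GPU_INSTANCE_TYPES,
    # so no warning should be logged.
    fw_utils.warn_if_parameter_server_with_multi_gpu(
        training_instance_type="ml.p3.2xlarge",
        distributions={"parameter_server": {"enabled": True}},
    )
    assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING not in caplog.text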
