|
13 | 13 | """Utility methods used by framework classes"""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
| 16 | +import logging |
16 | 17 | import os
|
17 | 18 | import re
|
18 | 19 | import shutil
|
|
23 | 24 | from sagemaker import s3
|
24 | 25 | from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
|
25 | 26 |
|
| 27 | +logger = logging.getLogger("sagemaker") |
| 28 | + |
# Filename used when packaging the user's source directory for upload to S3.
_TAR_SOURCE_FILENAME = "source.tar.gz"

# Location (S3 prefix) and entry-point script of user code uploaded for training.
# NOTE(review): the namedtuple typename "UserCode" differs from the variable name
# UploadedCode — presumably kept for backward compatibility (e.g. pickled
# instances); confirm before changing either name.
UploadedCode = namedtuple("UserCode", ["s3_prefix", "script_name"])
|
|
42 | 45 | "Python 2. Newer versions of {framework} will only be available for Python 3."
|
43 | 46 | "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
|
44 | 47 | )
|
# Warning emitted by warn_if_parameter_server_with_multi_gpu() when parameter
# server distributed training is enabled on a multi-GPU instance type.
PARAMETER_SERVER_MULTI_GPU_WARNING = (
    "You have selected a multi-GPU training instance type. "
    "You have also enabled parameter server for distributed training. "
    "Distributed training with the default parameter server configuration will not "
    "fully leverage all GPU cores; the parameter server will be configured to run "
    "only one worker per host regardless of the number of GPUs."
)
45 | 55 |
|
46 | 56 |
|
47 | 57 | EMPTY_FRAMEWORK_VERSION_ERROR = (
|
|
68 | 78 | DEFAULT_ACCOUNT = "520713654638"
|
69 | 79 | ASIMOV_PROD_ACCOUNT = "763104351884"
|
70 | 80 | ASIMOV_DEFAULT_ACCOUNT = ASIMOV_PROD_ACCOUNT
|
# Instance types that carry a single GPU; these are excluded from the
# multi-GPU parameter server warning in warn_if_parameter_server_with_multi_gpu().
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
71 | 82 |
|
72 | 83 | MERGED_FRAMEWORKS_REPO_MAP = {
|
73 | 84 | "tensorflow-scriptmode": "tensorflow-training",
|
@@ -490,6 +501,44 @@ def empty_framework_version_warning(default_version, latest_version):
|
490 | 501 | return " ".join(msgs)
|
491 | 502 |
|
492 | 503 |
|
def warn_if_parameter_server_with_multi_gpu(training_instance_type, distributions):
    """Warn the user that training will not fully leverage all the GPU
    cores if parameter server is enabled and a multi-GPU instance is selected.
    Distributed training with the default parameter server setup doesn't
    support multi-GPU instances.

    Args:
        training_instance_type (str): A string representing the type of training instance
            selected (e.g. "ml.p3.8xlarge", or "local"/"local_gpu" for local mode).
        distributions (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    'parameter_server':
                    {
                        'enabled': True
                    }
                }


    """
    if training_instance_type == "local" or distributions is None:
        return

    # "local_gpu" has no "." component, so it must be handled before the
    # split below (the original code raised IndexError for it); it uses every
    # GPU on the local machine, so treat it as potentially multi-GPU.
    # For SageMaker instance types ("ml.<family>.<size>"), GPU families such
    # as p2/p3 start with "p" after the first dot.
    is_multi_gpu_instance = (
        training_instance_type == "local_gpu"
        or training_instance_type.split(".")[1].startswith("p")
    ) and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES

    # Parameter server counts as enabled only when explicitly flagged.
    ps_enabled = "parameter_server" in distributions and distributions["parameter_server"].get(
        "enabled", False
    )

    if is_multi_gpu_instance and ps_enabled:
        logger.warning(PARAMETER_SERVER_MULTI_GPU_WARNING)
| 541 | + |
493 | 542 | def get_unsupported_framework_version_error(
|
494 | 543 | framework_name, unsupported_version, supported_versions
|
495 | 544 | ):
|
|
0 commit comments