Skip to content

Commit ff03eae

Browse files
ajaykarpurChoiByungWook
authored andcommitted
feat: auto-select container version for p4d and smdistributed (#517)
1 parent fe85356 commit ff03eae

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,13 +2297,18 @@ def training_image_uri(self):
22972297
"""
22982298
if self.image_uri:
22992299
return self.image_uri
2300+
if hasattr(self, "distribution"):
2301+
distribution = self.distribution # pylint: disable=no-member
2302+
else:
2303+
distribution = None
23002304
return image_uris.retrieve(
23012305
self._framework_name,
23022306
self.sagemaker_session.boto_region_name,
23032307
instance_type=self.instance_type,
23042308
version=self.framework_version, # pylint: disable=no-member
23052309
py_version=self.py_version, # pylint: disable=no-member
23062310
image_scope="training",
2311+
distribution=distribution,
23072312
)
23082313

23092314
@classmethod

src/sagemaker/image_uris.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def retrieve(
3535
accelerator_type=None,
3636
image_scope=None,
3737
container_version=None,
38+
distribution=None,
3839
):
3940
"""Retrieves the ECR URI for the Docker image matching the given arguments.
4041
@@ -54,6 +55,8 @@ def retrieve(
5455
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5556
``image_scope`` is ignored.
5657
container_version (str): the version of docker image
58+
distribution (dict): A dictionary with information on how to run distributed training
59+
(default: None).
5760
5861
Returns:
5962
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -77,10 +80,25 @@ def retrieve(
7780
processor = _processor(
7881
instance_type, config.get("processors") or version_config.get("processors")
7982
)
83+
8084
tag = _format_tag(
81-
version_config.get("tag_prefix", version), processor, py_version, container_version
85+
version_config.get("tag_prefix", version),
86+
processor,
87+
py_version,
88+
container_version,
8289
)
8390

91+
if _should_auto_select_container_version(instance_type, distribution):
92+
container_versions = {
93+
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
94+
"tensorflow-1.15-gpu-py37": "cu110-ubuntu18.04-v8",
95+
"mxnet-1.8-gpu-py37": "cu110-ubuntu16.04-v1",
96+
"pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3",
97+
}
98+
key = "-".join([framework, tag])
99+
if key in container_versions:
100+
tag = "-".join([tag, container_versions[key]])
101+
84102
if tag:
85103
repo += ":{}".format(tag)
86104

@@ -217,6 +235,23 @@ def _processor(instance_type, available_processors):
217235
return processor
218236

219237

238+
def _should_auto_select_container_version(instance_type, distribution):
239+
"""Returns a boolean that indicates whether to use an auto-selected container version."""
240+
p4d = False
241+
if instance_type:
242+
# looks for either "ml.<family>.<size>" or "ml_<family>"
243+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
244+
if match:
245+
family = match[1]
246+
p4d = family == "p4d"
247+
248+
smdistributed = False
249+
if distribution:
250+
smdistributed = "smdistributed" in distribution
251+
252+
return p4d or smdistributed
253+
254+
220255
def _validate_py_version_and_set_if_needed(py_version, version_config, framework):
221256
"""Checks if the Python version is one of the supported versions."""
222257
if "repository" in version_config:

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,21 @@ def test_retrieve_default_processor_type_if_possible(config_for_framework):
538538
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
539539

540540

541+
def test_retrieve_auto_selected_container_version():
542+
uri = image_uris.retrieve(
543+
framework="tensorflow",
544+
region="us-west-2",
545+
version="2.3",
546+
py_version="py37",
547+
instance_type="ml.p4d.24xlarge",
548+
image_scope="training",
549+
)
550+
assert (
551+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.3-gpu-py37-cu110-ubuntu18.04-v3"
552+
== uri
553+
)
554+
555+
541556
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
542557
def test_retrieve_unsupported_processor_type(config_for_framework):
543558
with pytest.raises(ValueError) as e:

0 commit comments

Comments
 (0)