@@ -35,6 +35,7 @@ def retrieve(
35
35
accelerator_type = None ,
36
36
image_scope = None ,
37
37
container_version = None ,
38
+ distribution = None ,
38
39
):
39
40
"""Retrieves the ECR URI for the Docker image matching the given arguments.
40
41
@@ -54,6 +55,8 @@ def retrieve(
54
55
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
55
56
``image_scope`` is ignored.
56
57
container_version (str): the version of docker image
58
+ distribution (dict): A dictionary with information on how to run distributed training
59
+ (default: None).
57
60
58
61
Returns:
59
62
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -77,10 +80,25 @@ def retrieve(
77
80
processor = _processor (
78
81
instance_type , config .get ("processors" ) or version_config .get ("processors" )
79
82
)
83
+
80
84
tag = _format_tag (
81
- version_config .get ("tag_prefix" , version ), processor , py_version , container_version
85
+ version_config .get ("tag_prefix" , version ),
86
+ processor ,
87
+ py_version ,
88
+ container_version ,
82
89
)
83
90
91
+ if _should_auto_select_container_version (instance_type , distribution ):
92
+ container_versions = {
93
+ "tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
94
+ "tensorflow-1.15-gpu-py37" : "cu110-ubuntu18.04-v8" ,
95
+ "mxnet-1.8-gpu-py37" : "cu110-ubuntu16.04-v1" ,
96
+ "pytorch-1.6-gpu-py36" : "cu110-ubuntu18.04-v3" ,
97
+ }
98
+ key = "-" .join ([framework , tag ])
99
+ if key in container_versions :
100
+ tag = "-" .join ([tag , container_versions [key ]])
101
+
84
102
if tag :
85
103
repo += ":{}" .format (tag )
86
104
@@ -217,6 +235,23 @@ def _processor(instance_type, available_processors):
217
235
return processor
218
236
219
237
238
+ def _should_auto_select_container_version (instance_type , distribution ):
239
+ """Returns a boolean that indicates whether to use an auto-selected container version."""
240
+ p4d = False
241
+ if instance_type :
242
+ # looks for either "ml.<family>.<size>" or "ml_<family>"
243
+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
244
+ if match :
245
+ family = match [1 ]
246
+ p4d = family == "p4d"
247
+
248
+ smdistributed = False
249
+ if distribution :
250
+ smdistributed = "smdistributed" in distribution
251
+
252
+ return p4d or smdistributed
253
+
254
+
220
255
def _validate_py_version_and_set_if_needed (py_version , version_config , framework ):
221
256
"""Checks if the Python version is one of the supported versions."""
222
257
if "repository" in version_config :
0 commit comments