Skip to content

Commit 7e77981

Browse files
authored
fix: account for "py36" and "py37" in image tag parsing (#1737)
1 parent 6b0e009 commit 7e77981

File tree

4 files changed

+48
-14
lines changed

4 files changed

+48
-14
lines changed

src/sagemaker/fw_utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ def _list_files_to_compress(script, directory):
493493
def framework_name_from_image(image_name):
494494
# noinspection LongLine
495495
"""Extract the framework and Python version from the image name.
496+
496497
Args:
497498
image_name (str): Image URI, which should be one of the following forms:
498499
legacy:
@@ -503,25 +504,32 @@ def framework_name_from_image(image_name):
503504
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
504505
current:
505506
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
507+
current:
508+
'<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
509+
506510
Returns:
507511
tuple: A tuple containing:
508-
str: The framework name str: The Python version str: The image tag
509-
str: If the image is script mode
510-
"""
512+
513+
- str: The framework name
514+
- str: The Python version
515+
- str: The image tag
516+
- str: If the TensorFlow image is script mode
517+
"""
511518
sagemaker_pattern = re.compile(ECR_URI_PATTERN)
512519
sagemaker_match = sagemaker_pattern.match(image_name)
513520
if sagemaker_match is None:
514521
return None, None, None, None
522+
515523
# extract framework, python version and image tag
516524
# We must support both the legacy and current image name format.
517525
name_pattern = re.compile(
518-
r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 # pylint: disable=line-too-long
526+
r"""^(?:sagemaker(?:-rl)?-)?
527+
(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?
528+
(scriptmode|training)?
529+
:(.*)-(.*?)-(py2|py3[67]?)$""",
530+
re.VERBOSE,
519531
)
520-
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
521-
522532
name_match = name_pattern.match(sagemaker_match.group(9))
523-
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
524-
525533
if name_match is not None:
526534
fw, scriptmode, ver, device, py = (
527535
name_match.group(1),
@@ -531,20 +539,25 @@ def framework_name_from_image(image_name):
531539
name_match.group(5),
532540
)
533541
return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode
542+
543+
legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
544+
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
534545
if legacy_match is not None:
535546
return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None)
536547
return None, None, None, None
537548

538549

539550
def framework_version_from_tag(image_tag):
540551
"""Extract the framework version from the image tag.
552+
541553
Args:
542554
image_tag (str): Image tag, which should take the form
543555
'<framework_version>-<device>-<py_version>'
556+
544557
Returns:
545558
str: The framework version.
546559
"""
547-
tag_pattern = re.compile("^(.*)-(cpu|gpu)-(py2|py3)$")
560+
tag_pattern = re.compile("^(.*)-(cpu|gpu)-(py2|py3[67]?)$")
548561
tag_match = tag_pattern.match(image_tag)
549562
return None if tag_match is None else tag_match.group(1)
550563

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def _default_s3_path(self, directory, mpi=False):
728728

729729
def _script_mode_enabled(self):
730730
"""Placeholder docstring"""
731-
return self.py_version == "py3" or self.script_mode
731+
return self.py_version.startswith("py3") or self.script_mode
732732

733733
def _validate_and_set_debugger_configs(self):
734734
"""

tests/unit/test_fw_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,17 @@ def test_framework_name_from_image_rl():
11561156
)
11571157

11581158

1159+
def test_framework_name_from_image_python_versions():
1160+
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.2-cpu-py37"
1161+
assert ("tensorflow", "py37", "2.2-cpu-py37", "training") == fw_utils.framework_name_from_image(
1162+
image_name
1163+
)
1164+
1165+
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.15.2-cpu-py36"
1166+
expected_result = ("tensorflow", "py36", "1.15.2-cpu-py36", "training")
1167+
assert expected_result == fw_utils.framework_name_from_image(image_name)
1168+
1169+
11591170
def test_legacy_name_from_framework_image():
11601171
image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2"
11611172
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_name)
@@ -1200,8 +1211,17 @@ def test_legacy_name_from_image_any_tag():
12001211

12011212

12021213
def test_framework_version_from_tag():
1203-
version = fw_utils.framework_version_from_tag("1.5rc-keras-gpu-py2")
1204-
assert version == "1.5rc-keras"
1214+
tags = (
1215+
"1.5rc-keras-cpu-py2",
1216+
"1.5rc-keras-gpu-py2",
1217+
"1.5rc-keras-cpu-py3",
1218+
"1.5rc-keras-gpu-py36",
1219+
"1.5rc-keras-gpu-py37",
1220+
)
1221+
1222+
for tag in tags:
1223+
version = fw_utils.framework_version_from_tag(tag)
1224+
assert "1.5rc-keras" == version
12051225

12061226

12071227
def test_framework_version_from_tag_other():

tests/unit/test_tf_estimator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,9 @@ def test_legacy_mode_deprecated(sagemaker_session):
11261126

11271127

11281128
def test_script_mode_enabled(sagemaker_session):
1129-
tf = _build_tf(sagemaker_session=sagemaker_session, py_version="py3")
1130-
assert tf._script_mode_enabled() is True
1129+
for py_version in ("py3", "py36", "py37"):
1130+
tf = _build_tf(sagemaker_session=sagemaker_session, py_version=py_version)
1131+
assert tf._script_mode_enabled() is True
11311132

11321133
tf = _build_tf(sagemaker_session=sagemaker_session, script_mode=True)
11331134
assert tf._script_mode_enabled() is True

0 commit comments

Comments
 (0)