@@ -493,6 +493,7 @@ def _list_files_to_compress(script, directory):
493493def framework_name_from_image (image_name ):
494494 # noinspection LongLine
495495 """Extract the framework and Python version from the image name.
496+
496497 Args:
497498 image_name (str): Image URI, which should be one of the following forms:
498499 legacy:
@@ -503,25 +504,32 @@ def framework_name_from_image(image_name):
503504 '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
504505 current:
505506 '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
507+ current:
508+ '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
509+
506510 Returns:
507511 tuple: A tuple containing:
508- str: The framework name str: The Python version str: The image tag
509- str: If the image is script mode
510- """
512+
513+ - str: The framework name
514+ - str: The Python version
515+ - str: The image tag
516+ - str: If the TensorFlow image is script mode
517+ """
511518 sagemaker_pattern = re .compile (ECR_URI_PATTERN )
512519 sagemaker_match = sagemaker_pattern .match (image_name )
513520 if sagemaker_match is None :
514521 return None , None , None , None
522+
515523 # extract framework, python version and image tag
516524 # We must support both the legacy and current image name format.
517525 name_pattern = re .compile (
518- r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 # pylint: disable=line-too-long
526+ r"""^(?:sagemaker(?:-rl)?-)?
527+ (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?
528+ (scriptmode|training)?
529+ :(.*)-(.*?)-(py2|py3[67]?)$""" ,
530+ re .VERBOSE ,
519531 )
520- legacy_name_pattern = re .compile (r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$" )
521-
522532 name_match = name_pattern .match (sagemaker_match .group (9 ))
523- legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
524-
525533 if name_match is not None :
526534 fw , scriptmode , ver , device , py = (
527535 name_match .group (1 ),
@@ -531,20 +539,25 @@ def framework_name_from_image(image_name):
531539 name_match .group (5 ),
532540 )
533541 return fw , py , "{}-{}-{}" .format (ver , device , py ), scriptmode
542+
543+ legacy_name_pattern = re .compile (r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$" )
544+ legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
534545 if legacy_match is not None :
535546 return (legacy_match .group (1 ), legacy_match .group (2 ), legacy_match .group (4 ), None )
536547 return None , None , None , None
537548
538549
539550def framework_version_from_tag (image_tag ):
540551 """Extract the framework version from the image tag.
552+
541553 Args:
542554 image_tag (str): Image tag, which should take the form
543555 '<framework_version>-<device>-<py_version>'
556+
544557 Returns:
545558 str: The framework version.
546559 """
547- tag_pattern = re .compile ("^(.*)-(cpu|gpu)-(py2|py3)$" )
560+ tag_pattern = re .compile ("^(.*)-(cpu|gpu)-(py2|py3[67]? )$" )
548561 tag_match = tag_pattern .match (image_tag )
549562 return None if tag_match is None else tag_match .group (1 )
550563
0 commit comments