@@ -493,6 +493,7 @@ def _list_files_to_compress(script, directory):
493
493
def framework_name_from_image (image_name ):
494
494
# noinspection LongLine
495
495
"""Extract the framework and Python version from the image name.
496
+
496
497
Args:
497
498
image_name (str): Image URI, which should be one of the following forms:
498
499
legacy:
@@ -503,25 +504,32 @@ def framework_name_from_image(image_name):
503
504
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
504
505
current:
505
506
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
507
+ current:
508
+ '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
509
+
506
510
Returns:
507
511
tuple: A tuple containing:
508
- str: The framework name str: The Python version str: The image tag
509
- str: If the image is script mode
510
- """
512
+
513
+ - str: The framework name
514
+ - str: The Python version
515
+ - str: The image tag
516
+ - str: If the TensorFlow image is script mode
517
+ """
511
518
sagemaker_pattern = re .compile (ECR_URI_PATTERN )
512
519
sagemaker_match = sagemaker_pattern .match (image_name )
513
520
if sagemaker_match is None :
514
521
return None , None , None , None
522
+
515
523
# extract framework, python version and image tag
516
524
# We must support both the legacy and current image name format.
517
525
name_pattern = re .compile (
518
- r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 # pylint: disable=line-too-long
526
+ r"""^(?:sagemaker(?:-rl)?-)?
527
+ (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?
528
+ (scriptmode|training)?
529
+ :(.*)-(.*?)-(py2|py3[67]?)$""" ,
530
+ re .VERBOSE ,
519
531
)
520
- legacy_name_pattern = re .compile (r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$" )
521
-
522
532
name_match = name_pattern .match (sagemaker_match .group (9 ))
523
- legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
524
-
525
533
if name_match is not None :
526
534
fw , scriptmode , ver , device , py = (
527
535
name_match .group (1 ),
@@ -531,20 +539,25 @@ def framework_name_from_image(image_name):
531
539
name_match .group (5 ),
532
540
)
533
541
return fw , py , "{}-{}-{}" .format (ver , device , py ), scriptmode
542
+
543
+ legacy_name_pattern = re .compile (r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$" )
544
+ legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
534
545
if legacy_match is not None :
535
546
return (legacy_match .group (1 ), legacy_match .group (2 ), legacy_match .group (4 ), None )
536
547
return None , None , None , None
537
548
538
549
539
550
def framework_version_from_tag (image_tag ):
540
551
"""Extract the framework version from the image tag.
552
+
541
553
Args:
542
554
image_tag (str): Image tag, which should take the form
543
555
'<framework_version>-<device>-<py_version>'
556
+
544
557
Returns:
545
558
str: The framework version.
546
559
"""
547
- tag_pattern = re .compile ("^(.*)-(cpu|gpu)-(py2|py3)$" )
560
+ tag_pattern = re .compile ("^(.*)-(cpu|gpu)-(py2|py3[67]? )$" )
548
561
tag_match = tag_pattern .match (image_tag )
549
562
return None if tag_match is None else tag_match .group (1 )
550
563
0 commit comments