@@ -410,6 +410,8 @@ def framework_name_from_image(image_uri):
410410 '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-rl-<fw>:<rl_toolkit><rl_version>-<device>-<py_ver>'
411411 current:
412412 '<account>.dkr.ecr.<region>.amazonaws.com/<fw>-<image_scope>:<fw_version>-<device>-<py_ver>'
413+ current:
414+ '<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-xgboost:<fw_version>-<container_version>'
413415
414416 Returns:
415417 tuple: A tuple containing:
@@ -450,6 +452,16 @@ def framework_name_from_image(image_uri):
450452 legacy_match = legacy_name_pattern .match (sagemaker_match .group (9 ))
451453 if legacy_match is not None :
452454 return (legacy_match .group (1 ), legacy_match .group (2 ), legacy_match .group (4 ), None )
455+
456+ # sagemaker-xgboost images are tagged with two aliases, e.g.:
457+ # 1. Long tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1-cpu-py3"
458+ # 2. Short tag: "315553699071.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.5-1"
459+ # Note 1: Both tags point to the same image
460+ # Note 2: Both tags have full GPU capabilities, despite "cpu" delineation in the long tag
461+ short_xgboost_tag_pattern = re .compile (r"^sagemaker-(xgboost):(.*)$" )
462+ short_xgboost_tag_match = short_xgboost_tag_pattern .match (sagemaker_match .group (9 ))
463+ if short_xgboost_tag_match is not None :
464+ return (short_xgboost_tag_match .group (1 ), "py3" , short_xgboost_tag_match .group (2 ), None )
453465 return None , None , None , None
454466
455467
@@ -459,12 +471,16 @@ def framework_version_from_tag(image_tag):
459471 Args:
460472 image_tag (str): Image tag, which should take the form
461473 '<framework_version>-<device>-<py_version>'
474+ '<xgboost_version>-<container_version>'
462475
463476 Returns:
464477 str: The framework version.
465478 """
466479 tag_pattern = re .compile (r"^(.*)-(cpu|gpu)-(py2|py3\d*)$" )
467480 tag_match = tag_pattern .match (image_tag )
481+ if tag_match is None :
482+ short_xgboost_tag_pattern = re .compile (r"^(\d\.\d+\-\d)$" )
483+ tag_match = short_xgboost_tag_pattern .match (image_tag )
468484 return None if tag_match is None else tag_match .group (1 )
469485
470486
0 commit comments