Skip to content

Commit 372b14f

Browse files
committed
get_latest_container_image Support for multiple formats of config jsons
1 parent 94537be commit 372b14f

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

src/sagemaker/image_utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import Optional, Tuple
22

33
from sagemaker.image_uris import config_for_framework, retrieve
4-
4+
from packaging.version import Version
55

66
def get_latest_container_image(framework: str,
77
image_scope: str,
8+
instance_type: Optional[str] = None,
9+
py_version: Optional[str] = None,
810
region: str = "us-west-2",
911
version: Optional[str] = None) -> Tuple[str, str]:
1012
try:
@@ -19,9 +21,23 @@ def get_latest_container_image(framework: str,
1921
version = _fetch_latest_version_from_config(framework_config, image_scope)
2022
image_uri = retrieve(framework=framework,
2123
region=region,
22-
version=version)
24+
version=version,
25+
instance_type=instance_type,
26+
py_version=py_version
27+
)
2328
return image_uri, version
2429

2530

2631
def _fetch_latest_version_from_config(framework_config: dict, image_scope: str) -> str:
27-
return framework_config.get(image_scope).get("version_aliases").get("latest")
32+
if image_scope in framework_config:
33+
if image_scope_config := framework_config[image_scope]:
34+
if version_aliases := image_scope_config["version_aliases"]:
35+
if latest_version := version_aliases["latest"]:
36+
return latest_version
37+
versions = list(framework_config["versions"].keys())
38+
top_version = versions[0]
39+
bottom_version = versions[-1]
40+
41+
if Version(top_version) >= Version(bottom_version):
42+
return top_version
43+
return bottom_version

tests/unit/test_image_utils.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,48 @@
55

66

77
class TestImageUtils(unittest.TestCase):
8+
89
@patch('sagemaker.image_utils.config_for_framework')
910
@patch('sagemaker.image_utils.retrieve')
1011
def test_get_latest_container_image(self,
1112
mock_image_retrieve,
1213
mock_config_for_framework):
14+
mock_config_for_framework.return_value = {
15+
"versions": {
16+
"24.03": {
17+
"registries": {
18+
"af-south-1": "626614931356",
19+
},
20+
"repository": "sagemaker-tritonserver",
21+
"tag_prefix": "24.03-py3"
22+
},
23+
"24.01": {
24+
"registries": {
25+
"af-south-1": "626614931356"
26+
},
27+
"repository": "sagemaker-tritonserver",
28+
"tag_prefix": "24.01-py3"
29+
},
30+
"23.12": {
31+
"registries": {
32+
"af-south-1": "626614931356"
33+
},
34+
"repository": "sagemaker-tritonserver",
35+
"tag_prefix": "23.12-py3"
36+
}
37+
}
38+
}
39+
mock_image_retrieve.return_value = "latest-image"
40+
41+
image, version = get_latest_container_image("xgboost", "inference")
42+
assert image == "latest-image"
43+
assert version == "24.03"
44+
45+
@patch('sagemaker.image_utils.config_for_framework')
46+
@patch('sagemaker.image_utils.retrieve')
47+
def test_get_latest_container_image_with_alias(self,
48+
mock_image_retrieve,
49+
mock_config_for_framework):
1350
mock_config_for_framework.return_value = {
1451
"inference": {
1552
"version_aliases": {
@@ -24,9 +61,7 @@ def test_get_latest_container_image(self,
2461
assert version == "1"
2562

2663
@patch('sagemaker.image_utils.config_for_framework')
27-
@patch('sagemaker.image_utils.retrieve')
2864
def test_get_latest_container_image_invalid_framework(self,
29-
mock_image_retrieve,
3065
mock_config_for_framework):
3166
mock_config_for_framework.side_effect = FileNotFoundError
3267

@@ -35,12 +70,10 @@ def test_get_latest_container_image_invalid_framework(self,
3570
assert "No framework config for framework" in str(e.exception)
3671

3772
@patch('sagemaker.image_utils.config_for_framework')
38-
@patch('sagemaker.image_utils.retrieve')
3973
def test_get_latest_container_image_no_framework(self,
40-
mock_image_retrieve,
41-
mock_config_for_framework):
74+
mock_config_for_framework):
4275
mock_config_for_framework.return_value = {}
4376

4477
with self.assertRaises(ValueError) as e:
4578
get_latest_container_image("xgboost", "inference")
46-
assert "No framework config for framework" in str(e.exception)
79+
assert "No framework config for framework" in str(e.exception)

0 commit comments

Comments
 (0)