Skip to content

Commit 94537be

Browse files
committed
get_latest_container_image utils
1 parent 57e52ee commit 94537be

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/sagemaker/image_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Optional, Tuple
2+
3+
from sagemaker.image_uris import config_for_framework, retrieve
4+
5+
6+
def get_latest_container_image(framework: str,
7+
image_scope: str,
8+
region: str = "us-west-2",
9+
version: Optional[str] = None) -> Tuple[str, str]:
10+
try:
11+
framework_config = config_for_framework(framework)
12+
except FileNotFoundError:
13+
raise ValueError("Invalid framework {}".format(framework))
14+
15+
if not framework_config:
16+
raise ValueError("Invalid framework {}".format(framework))
17+
18+
if not version:
19+
version = _fetch_latest_version_from_config(framework_config, image_scope)
20+
image_uri = retrieve(framework=framework,
21+
region=region,
22+
version=version)
23+
return image_uri, version
24+
25+
26+
def _fetch_latest_version_from_config(framework_config: dict, image_scope: str) -> str:
27+
return framework_config.get(image_scope).get("version_aliases").get("latest")

tests/unit/test_image_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import unittest
2+
from unittest.mock import patch
3+
4+
from sagemaker.image_utils import get_latest_container_image
5+
6+
7+
class TestImageUtils(unittest.TestCase):
8+
@patch('sagemaker.image_utils.config_for_framework')
9+
@patch('sagemaker.image_utils.retrieve')
10+
def test_get_latest_container_image(self,
11+
mock_image_retrieve,
12+
mock_config_for_framework):
13+
mock_config_for_framework.return_value = {
14+
"inference": {
15+
"version_aliases": {
16+
"latest": "1"
17+
}
18+
}
19+
}
20+
mock_image_retrieve.return_value = "latest-image"
21+
22+
image, version = get_latest_container_image("xgboost", "inference")
23+
assert image == "latest-image"
24+
assert version == "1"
25+
26+
@patch('sagemaker.image_utils.config_for_framework')
27+
@patch('sagemaker.image_utils.retrieve')
28+
def test_get_latest_container_image_invalid_framework(self,
29+
mock_image_retrieve,
30+
mock_config_for_framework):
31+
mock_config_for_framework.side_effect = FileNotFoundError
32+
33+
with self.assertRaises(ValueError) as e:
34+
get_latest_container_image("xgboost", "inference")
35+
assert "No framework config for framework" in str(e.exception)
36+
37+
@patch('sagemaker.image_utils.config_for_framework')
38+
@patch('sagemaker.image_utils.retrieve')
39+
def test_get_latest_container_image_no_framework(self,
40+
mock_image_retrieve,
41+
mock_config_for_framework):
42+
mock_config_for_framework.return_value = {}
43+
44+
with self.assertRaises(ValueError) as e:
45+
get_latest_container_image("xgboost", "inference")
46+
assert "No framework config for framework" in str(e.exception)

0 commit comments

Comments
 (0)