1
+ import unittest
2
+ from unittest .mock import patch
3
+
4
+ from sagemaker .image_utils import get_latest_container_image
5
+
6
+
7
+ class TestImageUtils (unittest .TestCase ):
8
+ @patch ('sagemaker.image_utils.config_for_framework' )
9
+ @patch ('sagemaker.image_utils.retrieve' )
10
+ def test_get_latest_container_image (self ,
11
+ mock_image_retrieve ,
12
+ mock_config_for_framework ):
13
+ mock_config_for_framework .return_value = {
14
+ "inference" : {
15
+ "version_aliases" : {
16
+ "latest" : "1"
17
+ }
18
+ }
19
+ }
20
+ mock_image_retrieve .return_value = "latest-image"
21
+
22
+ image , version = get_latest_container_image ("xgboost" , "inference" )
23
+ assert image == "latest-image"
24
+ assert version == "1"
25
+
26
+ @patch ('sagemaker.image_utils.config_for_framework' )
27
+ @patch ('sagemaker.image_utils.retrieve' )
28
+ def test_get_latest_container_image_invalid_framework (self ,
29
+ mock_image_retrieve ,
30
+ mock_config_for_framework ):
31
+ mock_config_for_framework .side_effect = FileNotFoundError
32
+
33
+ with self .assertRaises (ValueError ) as e :
34
+ get_latest_container_image ("xgboost" , "inference" )
35
+ assert "No framework config for framework" in str (e .exception )
36
+
37
+ @patch ('sagemaker.image_utils.config_for_framework' )
38
+ @patch ('sagemaker.image_utils.retrieve' )
39
+ def test_get_latest_container_image_no_framework (self ,
40
+ mock_image_retrieve ,
41
+ mock_config_for_framework ):
42
+ mock_config_for_framework .return_value = {}
43
+
44
+ with self .assertRaises (ValueError ) as e :
45
+ get_latest_container_image ("xgboost" , "inference" )
46
+ assert "No framework config for framework" in str (e .exception )
0 commit comments