1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ from __future__ import absolute_import
14
+ import pytest
15
+ from unittest .mock import patch , Mock
16
+ from sagemaker .jumpstart import utils
17
+ from sagemaker .jumpstart .enums import JumpStartScriptScope , JumpStartConfigRankingName
18
+ from sagemaker .jumpstart .factory .estimator import _add_config_name_to_kwargs
19
+ from sagemaker .jumpstart .factory .model import (
20
+ _add_config_name_to_init_kwargs ,
21
+ _add_config_name_to_deploy_kwargs ,
22
+ )
23
+ from sagemaker .jumpstart .types import JumpStartEstimatorInitKwargs , JumpStartModelInitKwargs
24
+
25
+
26
+ class TestAutoConfigResolution :
27
+ """Test auto resolution of config names based on instance type."""
28
+
29
+ def create_mock_configs (self , scope ):
30
+ """Create mock configs for testing with different supported instance types."""
31
+ # Mock the config object structure
32
+ config1 = Mock ()
33
+ config1 .config_name = "config1"
34
+ config1 .resolved_config = {
35
+ "supported_inference_instance_types" : ["ml.g5.xlarge" , "ml.g5.2xlarge" ]
36
+ if scope == JumpStartScriptScope .INFERENCE
37
+ else [],
38
+ "supported_training_instance_types" : ["ml.g5.xlarge" , "ml.g5.2xlarge" ]
39
+ if scope == JumpStartScriptScope .TRAINING
40
+ else [],
41
+ }
42
+
43
+ config2 = Mock ()
44
+ config2 .config_name = "config2"
45
+ config2 .resolved_config = {
46
+ "supported_inference_instance_types" : ["ml.p4d.24xlarge" , "ml.p5.48xlarge" ]
47
+ if scope == JumpStartScriptScope .INFERENCE
48
+ else [],
49
+ "supported_training_instance_types" : ["ml.p4d.24xlarge" , "ml.p5.48xlarge" ]
50
+ if scope == JumpStartScriptScope .TRAINING
51
+ else [],
52
+ }
53
+
54
+ # Config with no instance type restrictions
55
+ config3 = Mock ()
56
+ config3 .config_name = "config3"
57
+ config3 .resolved_config = {
58
+ "supported_inference_instance_types" : []
59
+ if scope == JumpStartScriptScope .INFERENCE
60
+ else [],
61
+ "supported_training_instance_types" : []
62
+ if scope == JumpStartScriptScope .TRAINING
63
+ else [],
64
+ }
65
+
66
+ # Mock config rankings
67
+ ranking = Mock ()
68
+ ranking .rankings = ["config1" , "config2" , "config3" ]
69
+
70
+ # Mock the metadata configs container
71
+ configs = Mock ()
72
+ configs .scope = scope
73
+ configs .configs = {
74
+ "config1" : config1 ,
75
+ "config2" : config2 ,
76
+ "config3" : config3 ,
77
+ }
78
+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
79
+
80
+ # Import the actual get_top_config_from_ranking method so we can test it
81
+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
82
+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
83
+
84
+ return configs
85
+
86
+ def test_get_top_config_from_ranking_with_matching_instance_type (self ):
87
+ """Test that get_top_config_from_ranking returns config that supports the instance type."""
88
+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
89
+
90
+ # Test with instance type that matches config1
91
+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
92
+ assert result is not None
93
+ assert result .config_name == "config1"
94
+
95
+ # Test with instance type that matches config2
96
+ result = configs .get_top_config_from_ranking (instance_type = "ml.p4d.24xlarge" )
97
+ assert result is not None
98
+ assert result .config_name == "config2"
99
+
100
+ def test_get_top_config_from_ranking_with_no_matching_instance_type (self ):
101
+ """Test behavior when no config supports the requested instance type."""
102
+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
103
+
104
+ # Test with instance type that doesn't match any config
105
+ result = configs .get_top_config_from_ranking (instance_type = "ml.m5.xlarge" )
106
+ assert result is not None
107
+ assert result .config_name == "config3" # Should fall back to config with no restrictions
108
+
109
+ def test_get_top_config_from_ranking_without_instance_type (self ):
110
+ """Test that get_top_config_from_ranking returns first ranked config when no instance type specified."""
111
+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
112
+
113
+ result = configs .get_top_config_from_ranking ()
114
+ assert result is not None
115
+ assert result .config_name == "config1" # First in ranking
116
+
117
+ def test_get_top_config_from_ranking_training_scope (self ):
118
+ """Test get_top_config_from_ranking with training scope."""
119
+ configs = self .create_mock_configs (JumpStartScriptScope .TRAINING )
120
+
121
+ # Test with training instance type
122
+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
123
+ assert result is not None
124
+ assert result .config_name == "config1"
125
+
126
+ def test_get_top_config_from_ranking_with_object_resolved_config (self ):
127
+ """Test get_top_config_from_ranking when resolved_config is an object (not dict)."""
128
+ # Create a mock object with getattr support
129
+ mock_resolved_config = Mock ()
130
+ mock_resolved_config .supported_inference_instance_types = ["ml.g5.xlarge" ]
131
+
132
+ config = Mock ()
133
+ config .config_name = "test_config"
134
+ config .resolved_config = mock_resolved_config
135
+
136
+ ranking = Mock ()
137
+ ranking .rankings = ["test_config" ]
138
+
139
+ configs = Mock ()
140
+ configs .scope = JumpStartScriptScope .INFERENCE
141
+ configs .configs = {"test_config" : config }
142
+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
143
+
144
+ # Import the actual method
145
+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
146
+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
147
+
148
+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
149
+ assert result is not None
150
+ assert result .config_name == "test_config"
151
+
152
+ def test_get_top_config_from_ranking_empty_supported_instance_types (self ):
153
+ """Test behavior when config has empty supported_instance_types list."""
154
+ config = Mock ()
155
+ config .config_name = "empty_config"
156
+ config .resolved_config = {
157
+ "supported_inference_instance_types" : [],
158
+ }
159
+
160
+ ranking = Mock ()
161
+ ranking .rankings = ["empty_config" ]
162
+
163
+ configs = Mock ()
164
+ configs .scope = JumpStartScriptScope .INFERENCE
165
+ configs .configs = {"empty_config" : config }
166
+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
167
+
168
+ # Import the actual method
169
+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
170
+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
171
+
172
+ # Should return config even with empty list (no restrictions)
173
+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
174
+ assert result is not None
175
+ assert result .config_name == "empty_config"
176
+
177
+ def test_instance_type_parameter_signature (self ):
178
+ """Test that get_top_ranked_config_name function accepts instance_type parameter."""
179
+ # Import and inspect the function signature
180
+ import inspect
181
+ from typing import Optional
182
+ sig = inspect .signature (utils .get_top_ranked_config_name )
183
+
184
+ # Verify that instance_type parameter exists in the signature
185
+ assert "instance_type" in sig .parameters
186
+
187
+ # Verify it's an optional parameter with None default
188
+ instance_type_param = sig .parameters ["instance_type" ]
189
+ assert instance_type_param .default is None
190
+ assert instance_type_param .annotation == Optional [str ]
191
+
192
+ def test_get_top_config_from_ranking_preserves_existing_config_name (self ):
193
+ """Test that existing config_name is preserved when already specified."""
194
+ mock_get_config = Mock (return_value = "auto_selected" )
195
+
196
+ with patch ("sagemaker.jumpstart.utils.get_top_ranked_config_name" , mock_get_config ):
197
+ kwargs = JumpStartEstimatorInitKwargs (
198
+ model_id = "test-model" ,
199
+ instance_type = "ml.g5.xlarge" ,
200
+ config_name = "user_specified_config" ,
201
+ )
202
+
203
+ result = _add_config_name_to_kwargs (kwargs )
204
+
205
+ # Should not call get_top_ranked_config_name when config_name already exists
206
+ mock_get_config .assert_not_called ()
207
+ assert result .config_name == "user_specified_config"
208
+
209
+ def test_config_ranking_respects_priority_with_instance_type_filter (self ):
210
+ """Test that config ranking priority is respected when filtering by instance type."""
211
+ # Create configs where config2 is ranked higher but config1 matches instance type
212
+ config1 = Mock ()
213
+ config1 .config_name = "config1"
214
+ config1 .resolved_config = {"supported_inference_instance_types" : ["ml.g5.xlarge" ]}
215
+
216
+ config2 = Mock ()
217
+ config2 .config_name = "config2"
218
+ config2 .resolved_config = {"supported_inference_instance_types" : ["ml.p4d.24xlarge" ]}
219
+
220
+ # Rank config2 higher than config1
221
+ ranking = Mock ()
222
+ ranking .rankings = ["config2" , "config1" ]
223
+
224
+ configs = Mock ()
225
+ configs .scope = JumpStartScriptScope .INFERENCE
226
+ configs .configs = {"config1" : config1 , "config2" : config2 }
227
+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
228
+
229
+ # Import the actual method
230
+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
231
+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
232
+
233
+ # Even though config2 is ranked higher, config1 should be returned because it matches instance type
234
+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
235
+ assert result is not None
236
+ assert result .config_name == "config1"
0 commit comments