1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4+ # may not use this file except in compliance with the License. A copy of
5+ # the License is located at
6+ #
7+ # http://aws.amazon.com/apache2.0/
8+ #
9+ # or in the "license" file accompanying this file. This file is
10+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+ # ANY KIND, either express or implied. See the License for the specific
12+ # language governing permissions and limitations under the License.
13+ from __future__ import absolute_import
14+ import pytest
15+ from unittest .mock import patch , Mock
16+ from sagemaker .jumpstart import utils
17+ from sagemaker .jumpstart .enums import JumpStartScriptScope , JumpStartConfigRankingName
18+ from sagemaker .jumpstart .factory .estimator import _add_config_name_to_kwargs
19+ from sagemaker .jumpstart .factory .model import (
20+ _add_config_name_to_init_kwargs ,
21+ _add_config_name_to_deploy_kwargs ,
22+ )
23+ from sagemaker .jumpstart .types import JumpStartEstimatorInitKwargs , JumpStartModelInitKwargs
24+
25+
26+ class TestAutoConfigResolution :
27+ """Test auto resolution of config names based on instance type."""
28+
29+ def create_mock_configs (self , scope ):
30+ """Create mock configs for testing with different supported instance types."""
31+ # Mock the config object structure
32+ config1 = Mock ()
33+ config1 .config_name = "config1"
34+ config1 .resolved_config = {
35+ "supported_inference_instance_types" : ["ml.g5.xlarge" , "ml.g5.2xlarge" ]
36+ if scope == JumpStartScriptScope .INFERENCE
37+ else [],
38+ "supported_training_instance_types" : ["ml.g5.xlarge" , "ml.g5.2xlarge" ]
39+ if scope == JumpStartScriptScope .TRAINING
40+ else [],
41+ }
42+
43+ config2 = Mock ()
44+ config2 .config_name = "config2"
45+ config2 .resolved_config = {
46+ "supported_inference_instance_types" : ["ml.p4d.24xlarge" , "ml.p5.48xlarge" ]
47+ if scope == JumpStartScriptScope .INFERENCE
48+ else [],
49+ "supported_training_instance_types" : ["ml.p4d.24xlarge" , "ml.p5.48xlarge" ]
50+ if scope == JumpStartScriptScope .TRAINING
51+ else [],
52+ }
53+
54+ # Config with no instance type restrictions
55+ config3 = Mock ()
56+ config3 .config_name = "config3"
57+ config3 .resolved_config = {
58+ "supported_inference_instance_types" : []
59+ if scope == JumpStartScriptScope .INFERENCE
60+ else [],
61+ "supported_training_instance_types" : []
62+ if scope == JumpStartScriptScope .TRAINING
63+ else [],
64+ }
65+
66+ # Mock config rankings
67+ ranking = Mock ()
68+ ranking .rankings = ["config1" , "config2" , "config3" ]
69+
70+ # Mock the metadata configs container
71+ configs = Mock ()
72+ configs .scope = scope
73+ configs .configs = {
74+ "config1" : config1 ,
75+ "config2" : config2 ,
76+ "config3" : config3 ,
77+ }
78+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
79+
80+ # Import the actual get_top_config_from_ranking method so we can test it
81+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
82+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
83+
84+ return configs
85+
86+ def test_get_top_config_from_ranking_with_matching_instance_type (self ):
87+ """Test that get_top_config_from_ranking returns config that supports the instance type."""
88+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
89+
90+ # Test with instance type that matches config1
91+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
92+ assert result is not None
93+ assert result .config_name == "config1"
94+
95+ # Test with instance type that matches config2
96+ result = configs .get_top_config_from_ranking (instance_type = "ml.p4d.24xlarge" )
97+ assert result is not None
98+ assert result .config_name == "config2"
99+
100+ def test_get_top_config_from_ranking_with_no_matching_instance_type (self ):
101+ """Test behavior when no config supports the requested instance type."""
102+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
103+
104+ # Test with instance type that doesn't match any config
105+ result = configs .get_top_config_from_ranking (instance_type = "ml.m5.xlarge" )
106+ assert result is not None
107+ assert result .config_name == "config3" # Should fall back to config with no restrictions
108+
109+ def test_get_top_config_from_ranking_without_instance_type (self ):
110+ """Test that get_top_config_from_ranking returns first ranked config when no instance type specified."""
111+ configs = self .create_mock_configs (JumpStartScriptScope .INFERENCE )
112+
113+ result = configs .get_top_config_from_ranking ()
114+ assert result is not None
115+ assert result .config_name == "config1" # First in ranking
116+
117+ def test_get_top_config_from_ranking_training_scope (self ):
118+ """Test get_top_config_from_ranking with training scope."""
119+ configs = self .create_mock_configs (JumpStartScriptScope .TRAINING )
120+
121+ # Test with training instance type
122+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
123+ assert result is not None
124+ assert result .config_name == "config1"
125+
126+ def test_get_top_config_from_ranking_with_object_resolved_config (self ):
127+ """Test get_top_config_from_ranking when resolved_config is an object (not dict)."""
128+ # Create a mock object with getattr support
129+ mock_resolved_config = Mock ()
130+ mock_resolved_config .supported_inference_instance_types = ["ml.g5.xlarge" ]
131+
132+ config = Mock ()
133+ config .config_name = "test_config"
134+ config .resolved_config = mock_resolved_config
135+
136+ ranking = Mock ()
137+ ranking .rankings = ["test_config" ]
138+
139+ configs = Mock ()
140+ configs .scope = JumpStartScriptScope .INFERENCE
141+ configs .configs = {"test_config" : config }
142+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
143+
144+ # Import the actual method
145+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
146+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
147+
148+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
149+ assert result is not None
150+ assert result .config_name == "test_config"
151+
152+ def test_get_top_config_from_ranking_empty_supported_instance_types (self ):
153+ """Test behavior when config has empty supported_instance_types list."""
154+ config = Mock ()
155+ config .config_name = "empty_config"
156+ config .resolved_config = {
157+ "supported_inference_instance_types" : [],
158+ }
159+
160+ ranking = Mock ()
161+ ranking .rankings = ["empty_config" ]
162+
163+ configs = Mock ()
164+ configs .scope = JumpStartScriptScope .INFERENCE
165+ configs .configs = {"empty_config" : config }
166+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
167+
168+ # Import the actual method
169+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
170+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
171+
172+ # Should return config even with empty list (no restrictions)
173+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
174+ assert result is not None
175+ assert result .config_name == "empty_config"
176+
177+ def test_instance_type_parameter_signature (self ):
178+ """Test that get_top_ranked_config_name function accepts instance_type parameter."""
179+ # Import and inspect the function signature
180+ import inspect
181+ from typing import Optional
182+ sig = inspect .signature (utils .get_top_ranked_config_name )
183+
184+ # Verify that instance_type parameter exists in the signature
185+ assert "instance_type" in sig .parameters
186+
187+ # Verify it's an optional parameter with None default
188+ instance_type_param = sig .parameters ["instance_type" ]
189+ assert instance_type_param .default is None
190+ assert instance_type_param .annotation == Optional [str ]
191+
192+ def test_get_top_config_from_ranking_preserves_existing_config_name (self ):
193+ """Test that existing config_name is preserved when already specified."""
194+ mock_get_config = Mock (return_value = "auto_selected" )
195+
196+ with patch ("sagemaker.jumpstart.utils.get_top_ranked_config_name" , mock_get_config ):
197+ kwargs = JumpStartEstimatorInitKwargs (
198+ model_id = "test-model" ,
199+ instance_type = "ml.g5.xlarge" ,
200+ config_name = "user_specified_config" ,
201+ )
202+
203+ result = _add_config_name_to_kwargs (kwargs )
204+
205+ # Should not call get_top_ranked_config_name when config_name already exists
206+ mock_get_config .assert_not_called ()
207+ assert result .config_name == "user_specified_config"
208+
209+ def test_config_ranking_respects_priority_with_instance_type_filter (self ):
210+ """Test that config ranking priority is respected when filtering by instance type."""
211+ # Create configs where config2 is ranked higher but config1 matches instance type
212+ config1 = Mock ()
213+ config1 .config_name = "config1"
214+ config1 .resolved_config = {"supported_inference_instance_types" : ["ml.g5.xlarge" ]}
215+
216+ config2 = Mock ()
217+ config2 .config_name = "config2"
218+ config2 .resolved_config = {"supported_inference_instance_types" : ["ml.p4d.24xlarge" ]}
219+
220+ # Rank config2 higher than config1
221+ ranking = Mock ()
222+ ranking .rankings = ["config2" , "config1" ]
223+
224+ configs = Mock ()
225+ configs .scope = JumpStartScriptScope .INFERENCE
226+ configs .configs = {"config1" : config1 , "config2" : config2 }
227+ configs .config_rankings = {JumpStartConfigRankingName .DEFAULT : ranking }
228+
229+ # Import the actual method
230+ from sagemaker .jumpstart .types import JumpStartMetadataConfigs
231+ configs .get_top_config_from_ranking = JumpStartMetadataConfigs .get_top_config_from_ranking .__get__ (configs )
232+
233+ # Even though config2 is ranked higher, config1 should be returned because it matches instance type
234+ result = configs .get_top_config_from_ranking (instance_type = "ml.g5.xlarge" )
235+ assert result is not None
236+ assert result .config_name == "config1"
0 commit comments