Skip to content

Commit 2976c5c

Browse files
committed
tests: testing changes
1 parent d6c280f commit 2976c5c

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from unittest.mock import patch, Mock
16+
from sagemaker.jumpstart import utils
17+
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartConfigRankingName
18+
from sagemaker.jumpstart.factory.estimator import _add_config_name_to_kwargs
19+
from sagemaker.jumpstart.factory.model import (
20+
_add_config_name_to_init_kwargs,
21+
_add_config_name_to_deploy_kwargs,
22+
)
23+
from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs, JumpStartModelInitKwargs
24+
25+
26+
class TestAutoConfigResolution:
27+
"""Test auto resolution of config names based on instance type."""
28+
29+
def create_mock_configs(self, scope):
30+
"""Create mock configs for testing with different supported instance types."""
31+
# Mock the config object structure
32+
config1 = Mock()
33+
config1.config_name = "config1"
34+
config1.resolved_config = {
35+
"supported_inference_instance_types": ["ml.g5.xlarge", "ml.g5.2xlarge"]
36+
if scope == JumpStartScriptScope.INFERENCE
37+
else [],
38+
"supported_training_instance_types": ["ml.g5.xlarge", "ml.g5.2xlarge"]
39+
if scope == JumpStartScriptScope.TRAINING
40+
else [],
41+
}
42+
43+
config2 = Mock()
44+
config2.config_name = "config2"
45+
config2.resolved_config = {
46+
"supported_inference_instance_types": ["ml.p4d.24xlarge", "ml.p5.48xlarge"]
47+
if scope == JumpStartScriptScope.INFERENCE
48+
else [],
49+
"supported_training_instance_types": ["ml.p4d.24xlarge", "ml.p5.48xlarge"]
50+
if scope == JumpStartScriptScope.TRAINING
51+
else [],
52+
}
53+
54+
# Config with no instance type restrictions
55+
config3 = Mock()
56+
config3.config_name = "config3"
57+
config3.resolved_config = {
58+
"supported_inference_instance_types": []
59+
if scope == JumpStartScriptScope.INFERENCE
60+
else [],
61+
"supported_training_instance_types": []
62+
if scope == JumpStartScriptScope.TRAINING
63+
else [],
64+
}
65+
66+
# Mock config rankings
67+
ranking = Mock()
68+
ranking.rankings = ["config1", "config2", "config3"]
69+
70+
# Mock the metadata configs container
71+
configs = Mock()
72+
configs.scope = scope
73+
configs.configs = {
74+
"config1": config1,
75+
"config2": config2,
76+
"config3": config3,
77+
}
78+
configs.config_rankings = {JumpStartConfigRankingName.DEFAULT: ranking}
79+
80+
# Import the actual get_top_config_from_ranking method so we can test it
81+
from sagemaker.jumpstart.types import JumpStartMetadataConfigs
82+
configs.get_top_config_from_ranking = JumpStartMetadataConfigs.get_top_config_from_ranking.__get__(configs)
83+
84+
return configs
85+
86+
def test_get_top_config_from_ranking_with_matching_instance_type(self):
87+
"""Test that get_top_config_from_ranking returns config that supports the instance type."""
88+
configs = self.create_mock_configs(JumpStartScriptScope.INFERENCE)
89+
90+
# Test with instance type that matches config1
91+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
92+
assert result is not None
93+
assert result.config_name == "config1"
94+
95+
# Test with instance type that matches config2
96+
result = configs.get_top_config_from_ranking(instance_type="ml.p4d.24xlarge")
97+
assert result is not None
98+
assert result.config_name == "config2"
99+
100+
def test_get_top_config_from_ranking_with_no_matching_instance_type(self):
101+
"""Test behavior when no config supports the requested instance type."""
102+
configs = self.create_mock_configs(JumpStartScriptScope.INFERENCE)
103+
104+
# Test with instance type that doesn't match any config
105+
result = configs.get_top_config_from_ranking(instance_type="ml.m5.xlarge")
106+
assert result is not None
107+
assert result.config_name == "config3" # Should fall back to config with no restrictions
108+
109+
def test_get_top_config_from_ranking_without_instance_type(self):
110+
"""Test that get_top_config_from_ranking returns first ranked config when no instance type specified."""
111+
configs = self.create_mock_configs(JumpStartScriptScope.INFERENCE)
112+
113+
result = configs.get_top_config_from_ranking()
114+
assert result is not None
115+
assert result.config_name == "config1" # First in ranking
116+
117+
def test_get_top_config_from_ranking_training_scope(self):
118+
"""Test get_top_config_from_ranking with training scope."""
119+
configs = self.create_mock_configs(JumpStartScriptScope.TRAINING)
120+
121+
# Test with training instance type
122+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
123+
assert result is not None
124+
assert result.config_name == "config1"
125+
126+
def test_get_top_config_from_ranking_with_object_resolved_config(self):
127+
"""Test get_top_config_from_ranking when resolved_config is an object (not dict)."""
128+
# Create a mock object with getattr support
129+
mock_resolved_config = Mock()
130+
mock_resolved_config.supported_inference_instance_types = ["ml.g5.xlarge"]
131+
132+
config = Mock()
133+
config.config_name = "test_config"
134+
config.resolved_config = mock_resolved_config
135+
136+
ranking = Mock()
137+
ranking.rankings = ["test_config"]
138+
139+
configs = Mock()
140+
configs.scope = JumpStartScriptScope.INFERENCE
141+
configs.configs = {"test_config": config}
142+
configs.config_rankings = {JumpStartConfigRankingName.DEFAULT: ranking}
143+
144+
# Import the actual method
145+
from sagemaker.jumpstart.types import JumpStartMetadataConfigs
146+
configs.get_top_config_from_ranking = JumpStartMetadataConfigs.get_top_config_from_ranking.__get__(configs)
147+
148+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
149+
assert result is not None
150+
assert result.config_name == "test_config"
151+
152+
def test_get_top_config_from_ranking_empty_supported_instance_types(self):
153+
"""Test behavior when config has empty supported_instance_types list."""
154+
config = Mock()
155+
config.config_name = "empty_config"
156+
config.resolved_config = {
157+
"supported_inference_instance_types": [],
158+
}
159+
160+
ranking = Mock()
161+
ranking.rankings = ["empty_config"]
162+
163+
configs = Mock()
164+
configs.scope = JumpStartScriptScope.INFERENCE
165+
configs.configs = {"empty_config": config}
166+
configs.config_rankings = {JumpStartConfigRankingName.DEFAULT: ranking}
167+
168+
# Import the actual method
169+
from sagemaker.jumpstart.types import JumpStartMetadataConfigs
170+
configs.get_top_config_from_ranking = JumpStartMetadataConfigs.get_top_config_from_ranking.__get__(configs)
171+
172+
# Should return config even with empty list (no restrictions)
173+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
174+
assert result is not None
175+
assert result.config_name == "empty_config"
176+
177+
def test_instance_type_parameter_signature(self):
178+
"""Test that get_top_ranked_config_name function accepts instance_type parameter."""
179+
# Import and inspect the function signature
180+
import inspect
181+
from typing import Optional
182+
sig = inspect.signature(utils.get_top_ranked_config_name)
183+
184+
# Verify that instance_type parameter exists in the signature
185+
assert "instance_type" in sig.parameters
186+
187+
# Verify it's an optional parameter with None default
188+
instance_type_param = sig.parameters["instance_type"]
189+
assert instance_type_param.default is None
190+
assert instance_type_param.annotation == Optional[str]
191+
192+
def test_get_top_config_from_ranking_preserves_existing_config_name(self):
193+
"""Test that existing config_name is preserved when already specified."""
194+
mock_get_config = Mock(return_value="auto_selected")
195+
196+
with patch("sagemaker.jumpstart.utils.get_top_ranked_config_name", mock_get_config):
197+
kwargs = JumpStartEstimatorInitKwargs(
198+
model_id="test-model",
199+
instance_type="ml.g5.xlarge",
200+
config_name="user_specified_config",
201+
)
202+
203+
result = _add_config_name_to_kwargs(kwargs)
204+
205+
# Should not call get_top_ranked_config_name when config_name already exists
206+
mock_get_config.assert_not_called()
207+
assert result.config_name == "user_specified_config"
208+
209+
def test_config_ranking_respects_priority_with_instance_type_filter(self):
210+
"""Test that config ranking priority is respected when filtering by instance type."""
211+
# Create configs where config2 is ranked higher but config1 matches instance type
212+
config1 = Mock()
213+
config1.config_name = "config1"
214+
config1.resolved_config = {"supported_inference_instance_types": ["ml.g5.xlarge"]}
215+
216+
config2 = Mock()
217+
config2.config_name = "config2"
218+
config2.resolved_config = {"supported_inference_instance_types": ["ml.p4d.24xlarge"]}
219+
220+
# Rank config2 higher than config1
221+
ranking = Mock()
222+
ranking.rankings = ["config2", "config1"]
223+
224+
configs = Mock()
225+
configs.scope = JumpStartScriptScope.INFERENCE
226+
configs.configs = {"config1": config1, "config2": config2}
227+
configs.config_rankings = {JumpStartConfigRankingName.DEFAULT: ranking}
228+
229+
# Import the actual method
230+
from sagemaker.jumpstart.types import JumpStartMetadataConfigs
231+
configs.get_top_config_from_ranking = JumpStartMetadataConfigs.get_top_config_from_ranking.__get__(configs)
232+
233+
# Even though config2 is ranked higher, config1 should be returned because it matches instance type
234+
result = configs.get_top_config_from_ranking(instance_type="ml.g5.xlarge")
235+
assert result is not None
236+
assert result.config_name == "config1"

0 commit comments

Comments
 (0)