Skip to content

Commit 7860704

Browse files
committed
add deep UTs to catch regressions and test E2E fully and more practically
1 parent c5eac9a commit 7860704

File tree

2 files changed

+243
-1
lines changed

2 files changed

+243
-1
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1662,7 +1662,7 @@ def deploy(
16621662
vpc_config=self.vpc_config,
16631663
enable_network_isolation=self._enable_network_isolation,
16641664
role=self.role,
1665-
live_logging=endpoint_logging,
1665+
live_logging=False, # TODO: enable when IC supports this
16661666
wait=wait,
16671667
)
16681668

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
from unittest.mock import MagicMock, patch, ANY
15+
16+
from sagemaker.session import Session
17+
from sagemaker.serve.builder.model_builder import ModelBuilder
18+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
19+
from sagemaker.resource_requirements import ResourceRequirements
20+
21+
ROLE_NAME = "SageMakerRole"
22+
23+
24+
def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_expected(
    sagemaker_session,
):
    """Gated JS model optimized with speculative decoding issues the expected create-model call.

    Mocks out ``Session.create_model`` and ``Session.endpoint_from_production_variants``
    and asserts the container definition carries the draft-model data source plus the
    matching ``OPTION_SPECULATIVE_DRAFT_MODEL`` environment entry.
    """
    with (
        patch.object(Session, "create_model", return_value="mock_model") as mock_create_model,
        patch.object(
            Session, "endpoint_from_production_variants"
        ) as mock_endpoint_from_production_variants,
    ):
        # Resolve the execution role ARN from the shared test role name.
        iam = sagemaker_session.boto_session.client("iam")
        role_arn = iam.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=SchemaBuilder("test", "test"),
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            speculative_decoding_config={
                "ModelProvider": "JumpStart",
                "ModelID": "meta-textgeneration-llama-3-2-1b",
                "AcceptEula": True,
            },
            accept_eula=True,
        )
        optimized_model.deploy()

        # The serving container must point the model server at the draft model path.
        expected_environment = {
            "SAGEMAKER_PROGRAM": "inference.py",
            "ENDPOINT_SERVER_TIMEOUT": "3600",
            "MODEL_CACHE_ROOT": "/opt/ml/model",
            "SAGEMAKER_ENV": "1",
            "HF_MODEL_ID": "/opt/ml/model",
            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
            "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/",
        }
        # Both the primary and the draft model are gated, so EULA acceptance propagates.
        expected_draft_model_source = {
            "ChannelName": "draft_model",
            "S3DataSource": {
                "S3Uri": ANY,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
                "ModelAccessConfig": {"AcceptEula": True},
            },
        }
        expected_model_data_source = {
            "S3DataSource": {
                "S3Uri": ANY,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
                "ModelAccessConfig": {"AcceptEula": True},
            }
        }

        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": expected_environment,
                "AdditionalModelDataSources": [expected_draft_model_source],
                "ModelDataSource": expected_model_data_source,
            },
            vpc_config=None,
            enable_network_isolation=True,
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once()
95+
96+
97+
def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_expected(
    sagemaker_session,
):
    """Sharded JS model deployed with resource requirements creates an inference component.

    Mocks the optimization-job wait, model/endpoint creation, and inference-component
    creation, then asserts the exact create-model and endpoint requests — in particular
    that network isolation and live logging are both disabled for this path.
    """
    with (
        patch.object(
            Session,
            "wait_for_optimization_job",
            return_value={"OptimizationJobName": "mock_optimization_job"},
        ),
        patch.object(Session, "create_model", return_value="mock_model") as mock_create_model,
        patch.object(
            Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
        ) as mock_endpoint_from_production_variants,
        patch.object(Session, "create_inference_component") as mock_create_inference_component,
    ):
        # Resolve the execution role ARN from the shared test role name.
        iam = sagemaker_session.boto_session.client("iam")
        role_arn = iam.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        # Stub the optimization-job API so optimize() never reaches the service.
        sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

        builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=SchemaBuilder("test", "test"),
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}},
            accept_eula=True,
        )
        optimized_model.deploy(
            resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
        )

        # The sharding override must surface in the container environment.
        expected_environment = {
            "SAGEMAKER_PROGRAM": "inference.py",
            "ENDPOINT_SERVER_TIMEOUT": "3600",
            "MODEL_CACHE_ROOT": "/opt/ml/model",
            "SAGEMAKER_ENV": "1",
            "HF_MODEL_ID": "/opt/ml/model",
            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
            "OPTION_TENSOR_PARALLEL_DEGREE": "8",
        }
        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": expected_environment,
                "ModelDataSource": {
                    "S3DataSource": {
                        "S3Uri": ANY,
                        "S3DataType": "S3Prefix",
                        "CompressionType": "None",
                        "ModelAccessConfig": {"AcceptEula": True},
                    }
                },
            },
            vpc_config=None,
            enable_network_isolation=False,  # should be set to false
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once_with(
            name=ANY,
            production_variants=ANY,
            tags=ANY,
            kms_key=ANY,
            vpc_config=ANY,
            enable_network_isolation=False,
            role=ANY,
            live_logging=False,  # this should be set to false for IC
            wait=True,
        )
        mock_create_inference_component.assert_called_once()
174+
175+
176+
def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are_expected(
    sagemaker_session,
):
    """Quantizing a pre-optimized gated JS model issues the expected create-model request.

    Mocks the optimization-job wait plus model/endpoint creation, and asserts the
    container environment carries the ``OPTION_QUANTIZE`` override while the gated
    model's EULA acceptance is propagated to the model data source.
    """
    with (
        patch.object(
            Session,
            "wait_for_optimization_job",
            return_value={"OptimizationJobName": "mock_optimization_job"},
        ),
        patch.object(Session, "create_model", return_value="mock_model") as mock_create_model,
        patch.object(
            Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
        ) as mock_endpoint_from_production_variants,
    ):
        # Resolve the execution role ARN from the shared test role name.
        iam_client = sagemaker_session.boto_session.client("iam")
        role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]

        # Stub the optimization-job API so optimize() never reaches the service.
        sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()

        schema_builder = SchemaBuilder("test", "test")
        model_builder = ModelBuilder(
            model="meta-textgeneration-llama-3-1-8b-instruct",
            schema_builder=schema_builder,
            sagemaker_session=sagemaker_session,
            role_arn=role_arn,
        )

        optimized_model = model_builder.optimize(
            instance_type="ml.g5.xlarge",  # set to small instance in case a network call is made
            quantization_config={
                "OverrideEnvironment": {
                    "OPTION_QUANTIZE": "fp8",
                },
            },
            accept_eula=True,
        )

        optimized_model.deploy()

        mock_create_model.assert_called_once_with(
            name=ANY,
            role=ANY,
            container_defs={
                "Image": ANY,
                "Environment": {
                    "SAGEMAKER_PROGRAM": "inference.py",
                    "ENDPOINT_SERVER_TIMEOUT": "3600",
                    "MODEL_CACHE_ROOT": "/opt/ml/model",
                    "SAGEMAKER_ENV": "1",
                    "HF_MODEL_ID": "/opt/ml/model",
                    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
                    "OPTION_QUANTIZE": "fp8",
                },
                "ModelDataSource": {
                    "S3DataSource": {
                        "S3Uri": ANY,
                        "S3DataType": "S3Prefix",
                        "CompressionType": "None",
                        "ModelAccessConfig": {"AcceptEula": True},
                    }
                },
            },
            vpc_config=None,
            # Unlike the sharding path (isolation disabled), quantization keeps network
            # isolation enabled here. The original inline comment said "should be set
            # to false" — a stale copy-paste contradicting the asserted value.
            enable_network_isolation=True,
            tags=ANY,
        )
        mock_endpoint_from_production_variants.assert_called_once()

0 commit comments

Comments
 (0)