Skip to content

Commit 114a716

Browse files
makungaj1 (Jonathan Makunga)
and co-authors authored
Bug fixes (#1496)
* Bug fixes
* Refactor
* ENV update
* Remove code duplication
* Fix integ tests
* Fix MB EULA bug
---------
Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 9a3f6ca commit 114a716

File tree

5 files changed

+127
-61
lines changed

5 files changed

+127
-61
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
_extracts_and_validates_speculative_model_source,
4646
_generate_channel_name,
4747
_generate_additional_model_data_sources,
48+
_is_s3_uri,
4849
)
4950
from sagemaker.serve.utils.predictors import (
5051
DjlLocalModePredictor,
@@ -750,6 +751,8 @@ def _optimize_for_jumpstart(
750751

751752
if pysdk_model_env_vars:
752753
self.pysdk_model.env.update(pysdk_model_env_vars)
754+
if accept_eula:
755+
self.pysdk_model.accept_eula = accept_eula
753756

754757
if quantization_config or compilation_config:
755758
return create_optimization_job_args
@@ -787,8 +790,9 @@ def _set_additional_model_source(
787790
if speculative_decoding_config:
788791
model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config)
789792
channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources)
793+
speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}"
790794

791-
if model_provider.lower() == "sagemaker":
795+
if model_provider == "sagemaker":
792796
additional_model_data_sources = self.pysdk_model.deployment_config.get(
793797
"DeploymentArgs", {}
794798
).get("AdditionalDataSources")
@@ -805,27 +809,31 @@ def _set_additional_model_source(
805809
raise ValueError(
806810
"Cannot find deployment config compatible for optimization job."
807811
)
808-
809-
self.pysdk_model.add_tags(
810-
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"},
811-
)
812812
else:
813-
s3_uri = _extracts_and_validates_speculative_model_source(
813+
model_source = _extracts_and_validates_speculative_model_source(
814814
speculative_decoding_config
815815
)
816816

817-
self.pysdk_model.additional_model_data_sources = (
818-
_generate_additional_model_data_sources(s3_uri, channel_name, accept_eula)
819-
)
820-
self.pysdk_model.add_tags(
821-
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "customer"},
822-
)
817+
if _is_s3_uri(model_source):
818+
self.pysdk_model.additional_model_data_sources = (
819+
_generate_additional_model_data_sources(
820+
model_source, channel_name, accept_eula
821+
)
822+
)
823+
else:
824+
speculative_draft_model = model_source
823825

824-
speculative_draft_model = f"/opt/ml/additional-model-data-sources/{channel_name}"
825826
self.pysdk_model.env = _update_environment_variables(
826827
self.pysdk_model.env,
827828
{"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model},
828829
)
830+
self.pysdk_model.add_tags(
831+
{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider},
832+
)
833+
if accept_eula and isinstance(self.pysdk_model.model_data, dict):
834+
self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = {
835+
"AcceptEula": True
836+
}
829837

830838
def _find_compatible_deployment_config(
831839
self, speculative_decoding_config: Optional[Dict] = None

src/sagemaker/serve/builder/model_builder.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -408,7 +408,8 @@ def _prepare_for_mode(self, should_upload_artifacts: bool = False):
408408
getattr(self, "model_hub", None) == ModelHub.JUMPSTART,
409409
should_upload=should_upload_artifacts,
410410
)
411-
self.env_vars.update(env_vars_sagemaker)
411+
if env_vars_sagemaker:
412+
self.env_vars.update(env_vars_sagemaker)
412413
return self.s3_upload_path, env_vars_sagemaker
413414
if self.mode == Mode.LOCAL_CONTAINER:
414415
# init the LocalContainerMode object
@@ -1026,6 +1027,12 @@ def _model_builder_optimize_wrapper(
10261027
)
10271028

10281029
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
1030+
1031+
if instance_type:
1032+
self.instance_type = instance_type
1033+
if role:
1034+
self.role = role
1035+
10291036
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
10301037
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
10311038

src/sagemaker/serve/mode/sagemaker_endpoint_mode.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def prepare(
7070
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
7171
) from e
7272

73-
upload_artifacts = None
73+
upload_artifacts = None, None
7474
if self.model_server == ModelServer.TORCHSERVE:
7575
upload_artifacts = self._upload_torchserve_artifacts(
7676
model_path=model_path,

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -25,23 +25,6 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28-
def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool:
29-
"""Checks whether an instance is compatible with Inferentia.
30-
31-
Args:
32-
instance_type (str): The instance type used for the compilation job.
33-
34-
Returns:
35-
bool: Whether the given instance type is Inferentia or Trainium.
36-
"""
37-
if isinstance(instance_type, str):
38-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
39-
if match:
40-
if match[1].startswith("inf") or match[1].startswith("trn"):
41-
return True
42-
return False
43-
44-
4528
def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool:
4629
"""Checks whether an instance is compatible with an optimization job.
4730
@@ -69,13 +52,16 @@ def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -
6952
"""
7053
recommended_image_uri = optimization_response["OptimizationOutput"]["RecommendedInferenceImage"]
7154
optimized_environment = optimization_response["OptimizationEnvironment"]
72-
s3_uri = optimization_response["ModelSource"]["S3"]
55+
s3_uri = optimization_response["OutputConfig"]["S3OutputLocation"]
7356
deployment_instance_type = optimization_response["DeploymentInstanceType"]
7457

7558
if recommended_image_uri:
7659
pysdk_model.image_uri = recommended_image_uri
7760
if optimized_environment:
78-
pysdk_model.env = optimized_environment
61+
if pysdk_model.env:
62+
pysdk_model.env.update(optimized_environment)
63+
else:
64+
pysdk_model.env = optimized_environment
7965
if s3_uri:
8066
pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri
8167
if deployment_instance_type:
@@ -258,3 +244,18 @@ def _generate_additional_model_data_sources(
258244
additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True}
259245

260246
return [additional_model_data_source]
247+
248+
249+
def _is_s3_uri(s3_uri: Optional[str]) -> bool:
250+
"""Checks whether an S3 URI is valid.
251+
252+
Args:
253+
s3_uri (Optional[str]): The S3 URI.
254+
255+
Returns:
256+
bool: Whether the S3 URI is valid.
257+
"""
258+
if s3_uri is None:
259+
return False
260+
261+
return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 77 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -19,19 +19,21 @@
1919
from sagemaker.enums import Tag
2020
from sagemaker.serve.utils.optimize_utils import (
2121
_generate_optimized_model,
22-
_is_inferentia_or_trainium,
2322
_update_environment_variables,
2423
_is_image_compatible_with_optimization_job,
2524
_extract_speculative_draft_model_provider,
2625
_validate_optimization_inputs,
2726
_extracts_and_validates_speculative_model_source,
27+
_is_s3_uri,
28+
_generate_additional_model_data_sources,
29+
_generate_channel_name,
2830
)
2931

3032
mock_optimization_job_output = {
31-
"OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:"
32-
"optimization-job/modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968",
33+
"OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job/"
34+
"modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691",
3335
"OptimizationJobStatus": "COMPLETED",
34-
"OptimizationJobName": "modelbuilderjob-6b09ffebeb0741b8a28b85623fd9c968",
36+
"OptimizationJobName": "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691",
3537
"ModelSource": {
3638
"S3": {
3739
"S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/"
@@ -46,7 +48,7 @@
4648
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
4749
"SAGEMAKER_PROGRAM": "inference.py",
4850
},
49-
"DeploymentInstanceType": "ml.g5.48xlarge",
51+
"DeploymentInstanceType": "ml.g5.2xlarge",
5052
"OptimizationConfigs": [
5153
{
5254
"ModelQuantizationConfig": {
@@ -55,40 +57,26 @@
5557
}
5658
}
5759
],
58-
"OutputConfig": {
59-
"S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/"
60-
},
60+
"OutputConfig": {"S3OutputLocation": "s3://quicksilver-model-data/llama-3-8b/quantized-1/"},
6161
"OptimizationOutput": {
6262
"RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
6363
},
64-
"RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628",
64+
"RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20240116T151132",
6565
"StoppingCondition": {"MaxRuntimeInSeconds": 36000},
6666
"ResponseMetadata": {
67-
"RequestId": "17ae151f-b51d-4194-8ba9-edbba068c90b",
67+
"RequestId": "a95253d5-c045-4708-8aac-9f0d327515f7",
6868
"HTTPStatusCode": 200,
6969
"HTTPHeaders": {
70-
"x-amzn-requestid": "17ae151f-b51d-4194-8ba9-edbba068c90b",
70+
"x-amzn-requestid": "a95253d5-c045-4708-8aac-9f0d327515f7",
7171
"content-type": "application/x-amz-json-1.1",
72-
"content-length": "1380",
73-
"date": "Thu, 20 Jun 2024 19:25:53 GMT",
72+
"content-length": "1371",
73+
"date": "Fri, 21 Jun 2024 04:27:42 GMT",
7474
},
7575
"RetryAttempts": 0,
7676
},
7777
}
7878

7979

80-
@pytest.mark.parametrize(
81-
"instance, expected",
82-
[
83-
("ml.trn1.2xlarge", True),
84-
("ml.inf2.xlarge", True),
85-
("ml.c7gd.4xlarge", False),
86-
],
87-
)
88-
def test_is_inferentia_or_trainium(instance, expected):
89-
assert _is_inferentia_or_trainium(instance) == expected
90-
91-
9280
@pytest.mark.parametrize(
9381
"image_uri, expected",
9482
[
@@ -124,17 +112,21 @@ def test_generate_optimized_model():
124112
"meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/"
125113
}
126114
}
115+
pysdk_model.env = {"OPTION_QUANTIZE": "awq"}
127116

128117
optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output)
129118

130119
assert (
131120
optimized_model.image_uri
132121
== mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"]
133122
)
134-
assert optimized_model.env == mock_optimization_job_output["OptimizationEnvironment"]
123+
assert optimized_model.env == {
124+
"OPTION_QUANTIZE": "awq",
125+
**mock_optimization_job_output["OptimizationEnvironment"],
126+
}
135127
assert (
136128
optimized_model.model_data["S3DataSource"]["S3Uri"]
137-
== mock_optimization_job_output["ModelSource"]["S3"]
129+
== mock_optimization_job_output["OutputConfig"]["S3OutputLocation"]
138130
)
139131
assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"]
140132
pysdk_model.add_tags.assert_called_once_with(
@@ -209,3 +201,61 @@ def test_extract_speculative_draft_model_s3_uri():
209201
def test_extract_speculative_draft_model_s3_uri_ex():
210202
with pytest.raises(ValueError):
211203
_extracts_and_validates_speculative_model_source({"ModelSource": None})
204+
205+
206+
def test_generate_channel_name():
207+
assert _generate_channel_name(None) is not None
208+
209+
additional_model_data_sources = _generate_additional_model_data_sources(
210+
"s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True
211+
)
212+
213+
assert _generate_channel_name(additional_model_data_sources) == "channel_name"
214+
215+
216+
def test_generate_additional_model_data_sources():
217+
model_source = _generate_additional_model_data_sources(
218+
"s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True
219+
)
220+
221+
assert model_source == [
222+
{
223+
"ChannelName": "channel_name",
224+
"S3DataSource": {
225+
"S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/",
226+
"S3DataType": "S3Prefix",
227+
"CompressionType": "None",
228+
"ModelAccessConfig": {"ACCEPT_EULA": True},
229+
},
230+
}
231+
]
232+
233+
model_source = _generate_additional_model_data_sources(
234+
"s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", False
235+
)
236+
237+
assert model_source == [
238+
{
239+
"ChannelName": "channel_name",
240+
"S3DataSource": {
241+
"S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/",
242+
"S3DataType": "S3Prefix",
243+
"CompressionType": "None",
244+
},
245+
}
246+
]
247+
248+
249+
@pytest.mark.parametrize(
250+
"s3_uri, expected",
251+
[
252+
(
253+
"s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/"
254+
"meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/",
255+
True,
256+
),
257+
("invalid://", False),
258+
],
259+
)
260+
def test_is_s3_uri(s3_uri, expected):
261+
assert _is_s3_uri(s3_uri) == expected

0 commit comments

Comments (0)