Skip to content

Commit 0151209

Browse files
grenmester and Jacky Lee authored
Add optimize to ModelBuilder (#1468)
* Add optimize to ModelBuilder * Add polling for job completion * fix UTs --------- Co-authored-by: Jacky Lee <[email protected]>
1 parent c4529e3 commit 0151209

File tree

4 files changed

+296
-12
lines changed

4 files changed

+296
-12
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@
6262
from sagemaker.serve.utils import task
6363
from sagemaker.serve.utils.exceptions import TaskNotFoundException
6464
from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
65+
from sagemaker.serve.utils.optimize_utils import (
66+
_is_compatible_with_compilation,
67+
_poll_optimization_job,
68+
)
6569
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
6670
from sagemaker.serve.utils.hardware_detector import (
6771
_get_gpu_info,
@@ -83,6 +87,7 @@
8387
from sagemaker.serve.validations.check_image_and_hardware_type import (
8488
validate_image_uri_and_hardware,
8589
)
90+
from sagemaker.utils import Tags
8691
from sagemaker.workflow.entities import PipelineVariable
8792
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
8893

@@ -804,8 +809,15 @@ def save(
804809
This function is available for models served by DJL serving.
805810
806811
Args:
807-
save_path (Optional[str]): The path where you want to save resources.
808-
s3_path (Optional[str]): The path where you want to upload resources.
812+
save_path (Optional[str]): The path where you want to save resources. Defaults to
813+
``None``.
814+
s3_path (Optional[str]): The path where you want to upload resources. Defaults to
815+
``None``.
816+
sagemaker_session (Optional[Session]): Session object which manages interactions
817+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
818+
function creates one using the default AWS configuration chain. Defaults to
819+
``None``.
820+
role_arn (Optional[str]): The IAM role arn. Defaults to ``None``.
809821
"""
810822
self.sagemaker_session = sagemaker_session or Session()
811823

@@ -915,3 +927,129 @@ def _try_fetch_gpu_info(self):
915927
raise ValueError(
916928
f"Unable to determine single GPU size for instance: [{self.instance_type}]"
917929
)
930+
931+
def optimize(self, *args, **kwargs) -> Type[Model]:
    """Public entry point for running a model optimization job.

    Refreshes ``self.serve_settings`` and then forwards every positional and keyword
    argument, unchanged, to ``_model_builder_optimize_wrapper``.

    Args:
        instance_type (str): Target deployment instance type that the model is optimized for.
        output_path (str): Specifies where to store the compiled/quantized model.
        role (Optional[str]): Execution role. Defaults to ``None``.
        tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
        job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
        quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
        compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
        env_vars (Optional[Dict]): Additional environment variables to run the optimization
            container. Defaults to ``None``.
        vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
        kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
            to S3. Defaults to ``None``.
        max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
            ``None``.
        sagemaker_session (Optional[Session]): Session object which manages interactions
            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
            function creates one using the default AWS configuration chain.

    Returns:
        Type[Model]: Whatever ``_model_builder_optimize_wrapper`` returns (annotated as a
            deployable ``Model`` object).
    """
    # The serve settings carry the telemetry opt-out flag, so they have to be loaded
    # before the telemetry-decorated wrapper is invoked.
    self.serve_settings = self._get_serve_setting()

    optimization_result = self._model_builder_optimize_wrapper(*args, **kwargs)
    return optimization_result
960+
961+
@_capture_telemetry("optimize")
def _model_builder_optimize_wrapper(
    self,
    instance_type: str,
    output_path: str,
    role: Optional[str] = None,
    tags: Optional[Tags] = None,
    job_name: Optional[str] = None,
    quantization_config: Optional[Dict] = None,
    compilation_config: Optional[Dict] = None,
    env_vars: Optional[Dict] = None,
    vpc_config: Optional[Dict] = None,
    kms_key: Optional[str] = None,
    max_runtime_in_sec: Optional[int] = None,
    sagemaker_session: Optional[Session] = None,
) -> Type[Model]:
    """Runs a model optimization job and waits for the poller to report back.

    Builds the ``create_optimization_job`` request from the arguments, submits it through
    ``self.sagemaker_session.sagemaker_client`` and then polls the job by name.

    Args:
        instance_type (str): Target deployment instance type that the model is optimized for.
        output_path (str): Specifies where to store the compiled/quantized model.
        role (Optional[str]): Execution role. Falls back to ``self.role_arn`` when not given.
            Defaults to ``None``.
        tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
        job_name (Optional[str]): The name of the model optimization job. When not given, a
            name of the form ``modelbuilderjob-<uuid4 hex>`` is generated. Defaults to ``None``.
        quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
        compilation_config (Optional[Dict]): Compilation configuration. Ignored (with a
            warning) when ``instance_type`` is not compatible with compilation. Defaults to
            ``None``.
        env_vars (Optional[Dict]): Additional environment variables to run the optimization
            container. Defaults to ``None``.
        vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
        kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading
            to S3. Defaults to ``None``.
        max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to
            ``None``.
        sagemaker_session (Optional[Session]): Session object which manages interactions
            with Amazon SageMaker APIs and any other AWS services needed. If not specified,
            the builder's existing session is reused, or a new one is created using the
            default AWS configuration chain. Defaults to ``None``.

    Returns:
        Type[Model]: Intended to be a deployable ``Model`` object. For now the raw response
            of ``create_optimization_job`` is returned (see the TODO at the end).

    Raises:
        Exception: If ``_poll_optimization_job`` returns a falsy value.
    """
    self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()

    # TODO: inject actual model source location based on different scenarios
    model_source = {"S3": {"S3Uri": self.model_path, "ModelAccessConfig": {"AcceptEula": True}}}

    optimization_configs = []
    if quantization_config:
        optimization_configs.append({"ModelQuantizationConfig": quantization_config})
    if compilation_config:
        if _is_compatible_with_compilation(instance_type):
            optimization_configs.append({"ModelCompilationConfig": compilation_config})
        else:
            # Adjacent string literals are concatenated verbatim, so the first one must end
            # with a space to keep the words apart in the logged message.
            logger.warning(
                "Model compilation is currently only supported for Inferentia and Trainium "
                "instances, ignoring `compilation_config`."
            )

    output_config = {"S3OutputLocation": output_path}
    if kms_key:
        output_config["KmsKeyId"] = kms_key

    job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
    create_optimization_job_args = {
        "OptimizationJobName": job_name,
        "ModelSource": model_source,
        "DeploymentInstanceType": instance_type,
        "OptimizationConfigs": optimization_configs,
        "OutputConfig": output_config,
        "RoleArn": role or self.role_arn,
    }

    # Optional request fields are only added when the caller supplied a value.
    if env_vars:
        create_optimization_job_args["OptimizationEnvironment"] = env_vars

    if max_runtime_in_sec:
        create_optimization_job_args["StoppingCondition"] = {
            "MaxRuntimeInSeconds": max_runtime_in_sec
        }

    # TODO: tag injection if it is a JumpStart model
    if tags:
        create_optimization_job_args["Tags"] = tags

    if vpc_config:
        create_optimization_job_args["VpcConfig"] = vpc_config

    response = self.sagemaker_session.sagemaker_client.create_optimization_job(
        **create_optimization_job_args
    )

    if not _poll_optimization_job(job_name, self.sagemaker_session):
        raise Exception("Optimization job timed out.")

    # TODO: return model created by optimization job
    return response
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Holds the util functions used for the optimize function"""
14+
from __future__ import absolute_import
15+
16+
import time
17+
import logging
18+
19+
from sagemaker import Session
20+
21+
logger = logging.getLogger(__name__)

# TODO: determine how long optimization jobs take
# Pause between two consecutive status requests made by the optimization job poller.
OPTIMIZE_POLLER_INTERVAL_SECS = 30
# Upper bound on the total time the poller keeps asking before it gives up (5 minutes).
OPTIMIZE_POLLER_MAX_TIMEOUT_SECS = 5 * 60
26+
27+
28+
def _is_compatible_with_compilation(instance_type: str) -> bool:
    """Tells whether a model can be compiled for the given instance type.

    The decision is based only on the name prefix of the instance type: names starting
    with ``ml.inf`` or ``ml.trn`` are accepted, everything else is rejected.

    Args:
        instance_type (str): The instance type used for the compilation job.

    Returns:
        bool: ``True`` if the instance type starts with a supported prefix, else ``False``.
    """
    supported_prefixes = ("ml.inf", "ml.trn")
    return instance_type.startswith(supported_prefixes)
38+
39+
40+
def _poll_optimization_job(job_name: str, sagemaker_session: Session) -> bool:
    """Polls an optimization job until it reports back or the poller times out.

    The job is described every ``OPTIMIZE_POLLER_INTERVAL_SECS`` seconds for at most
    ``OPTIMIZE_POLLER_MAX_TIMEOUT_SECS`` seconds.

    Args:
        job_name (str): The name of the optimization job.
        sagemaker_session (Session): Session object which manages interactions
            with Amazon SageMaker APIs and any other AWS services needed.

    Returns:
        bool: ``True`` as soon as ``describe_optimization_job`` returns a result, ``False``
            if the timeout elapses first.
    """
    logger.info("Polling status of optimization job %s", job_name)
    start_time = time.time()
    while time.time() - start_time < OPTIMIZE_POLLER_MAX_TIMEOUT_SECS:
        # boto3 client operations only accept keyword arguments; passing the job name
        # positionally raises a TypeError before any request is sent.
        result = sagemaker_session.sagemaker_client.describe_optimization_job(
            OptimizationJobName=job_name
        )
        # TODO: use correct condition to determine whether optimization job is complete
        if result is not None:
            return True
        time.sleep(OPTIMIZE_POLLER_INTERVAL_SECS)

    # Timed out without a usable response; report it explicitly instead of returning None.
    return False

src/sagemaker/serve/utils/telemetry_logger.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,22 @@ def wrapper(self, *args, **kwargs):
8080
response = None
8181
caught_ex = None
8282

83-
image_uri_tail = self.image_uri.split("/")[1]
84-
image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri)
85-
extra = (
86-
f"{func_name}"
87-
f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}"
88-
f"&x-imageTag={image_uri_tail}"
89-
f"&x-sdkVersion={SDK_VERSION}"
90-
f"&x-defaultImageUsage={image_uri_option}"
91-
)
83+
extra = f"{func_name}"
84+
85+
if self.model_server:
86+
extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}"
87+
88+
if self.image_uri:
89+
image_uri_tail = self.image_uri.split("/")[1]
90+
image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri)
91+
92+
if self.image_uri:
93+
extra += f"&x-imageTag={image_uri_tail}"
94+
95+
extra += f"&x-sdkVersion={SDK_VERSION}"
96+
97+
if self.image_uri:
98+
extra += f"&x-defaultImageUsage={image_uri_option}"
9299

93100
if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI:
94101
extra += f"&x-modelName={self.model}"

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
mock_image_uri = "abcd/efghijk"
4646
mock_1p_dlc_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com"
47-
mock_role_arn = "sample role arn"
47+
mock_role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
4848
mock_s3_model_data_url = "sample s3 data url"
4949
mock_secret_key = "mock_secret_key"
5050
mock_instance_type = "mock instance type"
@@ -2257,3 +2257,84 @@ def test_build_tensorflow_serving_non_mlflow_case(
22572257
mock_role_arn,
22582258
mock_session,
22592259
)
2260+
2261+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
@patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
def test_optimize(self, mock_send_telemetry, mock_get_serve_setting):
    """optimize() must translate its arguments into one create_optimization_job request."""
    # Serve settings with telemetry_opt_out disabled; the test asserts below that
    # telemetry is sent exactly once.
    serve_settings = Mock()
    serve_settings.telemetry_opt_out = False
    mock_get_serve_setting.return_value = serve_settings

    session = Mock()
    builder = ModelBuilder(
        model_path=MODEL_PATH,
        schema_builder=schema_builder,
        model=mock_fw_model,
        sagemaker_session=session,
    )

    # Arguments handed to optimize(), covering every optional request field.
    optimize_kwargs = {
        "instance_type": "ml.inf1.xlarge",
        "output_path": "s3://my-bucket/output",
        "role": mock_role_arn,
        "job_name": "my-optimization-job",
        "quantization_config": {
            "Image": "quantization-image-uri",
            "OverrideEnvironment": {"ENV_VAR": "value"},
        },
        "compilation_config": {
            "Image": "compilation-image-uri",
            "OverrideEnvironment": {"ENV_VAR": "value"},
        },
        "env_vars": {"Var1": "value", "Var2": "value"},
        "kms_key": "arn:aws:kms:us-west-2:123456789012:key/my-key-id",
        "max_runtime_in_sec": 3600,
        "tags": [
            {"Key": "Project", "Value": "my-project"},
            {"Key": "Environment", "Value": "production"},
        ],
        "vpc_config": {
            "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"],
            "Subnets": ["subnet-01234567", "subnet-89abcdef"],
        },
    }

    # The request the builder is expected to assemble from the arguments above.
    expected_request = {
        "OptimizationJobName": optimize_kwargs["job_name"],
        "ModelSource": {"S3": {"S3Uri": MODEL_PATH, "ModelAccessConfig": {"AcceptEula": True}}},
        "DeploymentInstanceType": optimize_kwargs["instance_type"],
        "OptimizationConfigs": [
            {"ModelQuantizationConfig": optimize_kwargs["quantization_config"]},
            {"ModelCompilationConfig": optimize_kwargs["compilation_config"]},
        ],
        "OutputConfig": {
            "S3OutputLocation": optimize_kwargs["output_path"],
            "KmsKeyId": optimize_kwargs["kms_key"],
        },
        "RoleArn": mock_role_arn,
        "OptimizationEnvironment": optimize_kwargs["env_vars"],
        "StoppingCondition": {"MaxRuntimeInSeconds": optimize_kwargs["max_runtime_in_sec"]},
        "Tags": [
            {"Key": "Project", "Value": "my-project"},
            {"Key": "Environment", "Value": "production"},
        ],
        "VpcConfig": optimize_kwargs["vpc_config"],
    }

    session.sagemaker_client.create_optimization_job.return_value = {
        "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job"
    }

    builder.optimize(**optimize_kwargs)

    mock_send_telemetry.assert_called_once()
    session.sagemaker_client.create_optimization_job.assert_called_once_with(
        **expected_request
    )

0 commit comments

Comments
 (0)