Skip to content

Commit 70e00fd

Browse files
author
Joseph Zhang
committed
Enable quantization and compilation in the same optimization job via ModelBuilder.
1 parent 19e56f3 commit 70e00fd

File tree

5 files changed

+152
-26
lines changed

5 files changed

+152
-26
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -718,20 +718,23 @@ def _optimize_for_jumpstart(
718718
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
719719
)
720720

721-
is_compilation = (not quantization_config) and (
722-
(compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
721+
is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium(
722+
instance_type
723723
)
724724

725725
pysdk_model_env_vars = dict()
726726
if is_compilation:
727727
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
728728

729+
# optimization_config can contain configs for both quantization and compilation
729730
optimization_config, override_env = _extract_optimization_config_and_env(
730731
quantization_config, compilation_config
731732
)
732-
if not optimization_config and is_compilation:
733+
if (
734+
not optimization_config or not optimization_config.get("ModelCompilationConfig")
735+
) and is_compilation:
733736
override_env = override_env or pysdk_model_env_vars
734-
optimization_config = {
737+
optimization_config["ModelCompilationConfig"] = {
735738
"ModelCompilationConfig": {
736739
"OverrideEnvironment": override_env,
737740
}
@@ -766,7 +769,7 @@ def _optimize_for_jumpstart(
766769
"OptimizationJobName": job_name,
767770
"ModelSource": model_source,
768771
"DeploymentInstanceType": self.instance_type,
769-
"OptimizationConfigs": [optimization_config],
772+
"OptimizationConfigs": [{k: v} for k, v in optimization_config.items()],
770773
"OutputConfig": output_config,
771774
"RoleArn": self.role_arn,
772775
}

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,9 +1235,6 @@ def _model_builder_optimize_wrapper(
12351235
if self.mode != Mode.SAGEMAKER_ENDPOINT:
12361236
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
12371237

1238-
if quantization_config and compilation_config:
1239-
raise ValueError("Quantization config and compilation config are mutually exclusive.")
1240-
12411238
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12421239
self.instance_type = instance_type or self.instance_type
12431240
self.role_arn = role_arn or self.role_arn
@@ -1345,7 +1342,9 @@ def _optimize_for_hf(
13451342
optimization_config, override_env = _extract_optimization_config_and_env(
13461343
quantization_config, compilation_config
13471344
)
1348-
create_optimization_job_args["OptimizationConfigs"] = [optimization_config]
1345+
create_optimization_job_args["OptimizationConfigs"] = [
1346+
{k: v} for k, v in optimization_config.items()
1347+
]
13491348
self.pysdk_model.env.update(override_env)
13501349

13511350
output_config = {"S3OutputLocation": output_path}

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,14 +271,25 @@ def _extract_optimization_config_and_env(
271271
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
272272
The optimization config and environment variables.
273273
"""
274+
optimization_config = {}
275+
quantization_override_env = {}
276+
compilation_override_env = {}
277+
274278
if quantization_config:
275-
return {"ModelQuantizationConfig": quantization_config}, quantization_config.get(
276-
"OverrideEnvironment"
277-
)
279+
optimization_config["ModelQuantizationConfig"] = quantization_config
280+
quantization_override_env = quantization_config.get("OverrideEnvironment")
281+
278282
if compilation_config:
279-
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
280-
"OverrideEnvironment"
281-
)
283+
optimization_config["ModelCompilationConfig"] = compilation_config
284+
compilation_override_env = compilation_config.get("OverrideEnvironment")
285+
286+
# Return the combined optimization config and the merged override environment if either config is present
287+
if optimization_config:
288+
return optimization_config, {
289+
**(quantization_override_env or {}),
290+
**(compilation_override_env or {}),
291+
}
292+
282293
return None, None
283294

284295

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,65 @@ def test_optimize_quantize_for_jumpstart(
11981198

11991199
self.assertIsNotNone(out_put)
12001200

1201+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
1202+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1203+
def test_optimize_quantize_and_compile_for_jumpstart(
1204+
self,
1205+
mock_serve_settings,
1206+
mock_telemetry,
1207+
):
1208+
mock_sagemaker_session = Mock()
1209+
mock_metadata_config = Mock()
1210+
mock_metadata_config.resolved_config = {
1211+
"supported_inference_instance_types": ["ml.inf2.48xlarge"],
1212+
"hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b",
1213+
}
1214+
1215+
mock_pysdk_model = Mock()
1216+
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
1217+
mock_pysdk_model.model_data = mock_model_data
1218+
mock_pysdk_model.image_uri = mock_tgi_image_uri
1219+
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
1220+
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
1221+
mock_pysdk_model.config_name = "config_name"
1222+
mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config}
1223+
1224+
sample_input = {
1225+
"inputs": "The diamondback terrapin or simply terrapin is a species "
1226+
"of turtle native to the brackish coastal tidal marshes of the",
1227+
"parameters": {"max_new_tokens": 1024},
1228+
}
1229+
sample_output = [
1230+
{
1231+
"generated_text": "The diamondback terrapin or simply terrapin is a "
1232+
"species of turtle native to the brackish coastal "
1233+
"tidal marshes of the east coast."
1234+
}
1235+
]
1236+
1237+
model_builder = ModelBuilder(
1238+
model="meta-textgeneration-llama-3-70b",
1239+
schema_builder=SchemaBuilder(sample_input, sample_output),
1240+
sagemaker_session=mock_sagemaker_session,
1241+
)
1242+
1243+
model_builder.pysdk_model = mock_pysdk_model
1244+
1245+
out_put = model_builder._optimize_for_jumpstart(
1246+
accept_eula=True,
1247+
quantization_config={
1248+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
1249+
},
1250+
compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
1251+
env_vars={
1252+
"OPTION_TENSOR_PARALLEL_DEGREE": "1",
1253+
"OPTION_MAX_ROLLING_BATCH_SIZE": "2",
1254+
},
1255+
output_path="s3://bucket/code/",
1256+
)
1257+
1258+
self.assertIsNotNone(out_put)
1259+
12011260
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12021261
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
12031262
@patch(

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2650,21 +2650,75 @@ def test_optimize_local_mode(self, mock_get_serve_setting):
26502650
),
26512651
)
26522652

2653+
@patch.object(ModelBuilder, "_prepare_for_mode")
26532654
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2654-
def test_optimize_exclusive_args(self, mock_get_serve_setting):
2655-
mock_sagemaker_session = Mock()
2655+
def test_optimize_for_hf_with_both_quantization_and_compilation(
2656+
self,
2657+
mock_get_serve_setting,
2658+
mock_prepare_for_mode,
2659+
):
2660+
mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
2661+
{
2662+
"S3DataSource": {
2663+
"CompressionType": "None",
2664+
"S3DataType": "S3Prefix",
2665+
"S3Uri": "s3://bucket/code/code/",
2666+
}
2667+
},
2668+
{"DTYPE": "bfloat16"},
2669+
)
2670+
2671+
mock_pysdk_model = Mock()
2672+
mock_pysdk_model.model_data = None
2673+
mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"}
2674+
26562675
model_builder = ModelBuilder(
2657-
model="meta-textgeneration-llama-3-70b",
2658-
sagemaker_session=mock_sagemaker_session,
2676+
model="meta-llama/Meta-Llama-3-8B-Instruct",
2677+
env_vars={"HF_TOKEN": "token"},
2678+
model_metadata={
2679+
"CUSTOM_MODEL_PATH": "s3://bucket/path/",
2680+
},
2681+
role_arn="role-arn",
2682+
instance_type="ml.g5.2xlarge",
26592683
)
26602684

2661-
self.assertRaisesRegex(
2662-
ValueError,
2663-
"Quantization config and compilation config are mutually exclusive.",
2664-
lambda: model_builder.optimize(
2665-
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2666-
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2667-
),
2685+
model_builder.pysdk_model = mock_pysdk_model
2686+
2687+
out_put = model_builder._optimize_for_hf(
2688+
job_name="job_name-123",
2689+
quantization_config={
2690+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
2691+
},
2692+
compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
2693+
output_path="s3://bucket/code/",
2694+
)
2695+
2696+
self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token")
2697+
self.assertEqual(model_builder.role_arn, "role-arn")
2698+
self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
2699+
self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
2700+
self.assertEqual(model_builder.pysdk_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2")
2701+
self.assertEqual(
2702+
out_put,
2703+
{
2704+
"OptimizationJobName": "job_name-123",
2705+
"DeploymentInstanceType": "ml.g5.2xlarge",
2706+
"RoleArn": "role-arn",
2707+
"ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
2708+
"OptimizationConfigs": [
2709+
{
2710+
"ModelQuantizationConfig": {
2711+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}
2712+
}
2713+
},
2714+
{
2715+
"ModelCompilationConfig": {
2716+
"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}
2717+
}
2718+
},
2719+
],
2720+
"OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
2721+
},
26682722
)
26692723

26702724
@patch.object(ModelBuilder, "_prepare_for_mode")

0 commit comments

Comments
 (0)