
Commit 7ec16e6

Author: Joseph Zhang

Enable quantization and compilation in the same optimization job via ModelBuilder, and add validations to block compilation jobs using TRTLLM and Llama-3.1.

1 parent 3d8ffb8 commit 7ec16e6

File tree

6 files changed: +370 −45 lines changed

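With the mutual-exclusion check removed from _model_builder_optimize_wrapper (see model_builder.py below), a single ModelBuilder.optimize() call can now carry both a quantization config and a compilation config. A minimal sketch of the new usage, modeled on the unit tests in this commit; the model ID, instance type, and S3 path are illustrative:

from sagemaker.serve import ModelBuilder, SchemaBuilder

# Sample input/output only shape the schema; the values are placeholders.
model_builder = ModelBuilder(
    model="meta-textgeneration-llama-3-70b",
    schema_builder=SchemaBuilder({"inputs": "Hello"}, [{"generated_text": "Hello world"}]),
)

# Quantization and compilation can now ride in the same optimization job.
optimized_model = model_builder.optimize(
    accept_eula=True,
    instance_type="ml.inf2.24xlarge",
    quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
    compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
    output_path="s3://bucket/code/",
)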

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 31 additions & 13 deletions

@@ -718,24 +718,36 @@ def _optimize_for_jumpstart(
                 f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
             )
 
-        is_compilation = (not quantization_config) and (
-            (compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
+        is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium(
+            instance_type
         )
 
         pysdk_model_env_vars = dict()
         if is_compilation:
             pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
 
-        optimization_config, override_env = _extract_optimization_config_and_env(
-            quantization_config, compilation_config
+        # optimization_config can contain configs for both quantization and compilation
+        optimization_config, quantization_override_env, compilation_override_env = (
+            _extract_optimization_config_and_env(quantization_config, compilation_config)
         )
-        if not optimization_config and is_compilation:
-            override_env = override_env or pysdk_model_env_vars
-            optimization_config = {
-                "ModelCompilationConfig": {
-                    "OverrideEnvironment": override_env,
-                }
-            }
+        if (
+            not optimization_config or not optimization_config.get("ModelCompilationConfig")
+        ) and is_compilation:
+            # Ensure optimization_config exists
+            if not optimization_config:
+                optimization_config = {}
+
+            # Fallback to default if override_env is None or empty
+            if not compilation_override_env:
+                compilation_override_env = pysdk_model_env_vars
+
+            # Update optimization_config with ModelCompilationConfig
+            override_compilation_config = (
+                {"OverrideEnvironment": compilation_override_env}
+                if compilation_override_env
+                else {}
+            )
+            optimization_config["ModelCompilationConfig"] = override_compilation_config
 
         if speculative_decoding_config:
             self._set_additional_model_source(speculative_decoding_config)

@@ -766,7 +778,7 @@ def _optimize_for_jumpstart(
             "OptimizationJobName": job_name,
             "ModelSource": model_source,
             "DeploymentInstanceType": self.instance_type,
-            "OptimizationConfigs": [optimization_config],
+            "OptimizationConfigs": [{k: v} for k, v in optimization_config.items()],
             "OutputConfig": output_config,
             "RoleArn": self.role_arn,
         }

@@ -789,7 +801,13 @@ def _optimize_for_jumpstart(
                 "AcceptEula": True
             }
 
-        optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
+        optimization_env_vars = _update_environment_variables(
+            optimization_env_vars,
+            {
+                **(quantization_override_env or {}),
+                **(compilation_override_env or {}),
+            },
+        )
         if optimization_env_vars:
             self.pysdk_model.env.update(optimization_env_vars)
         if quantization_config or is_compilation:
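With two override environments in play, the JumpStart path merges them into one dict before applying them to the model; because the compilation environment is spread last, its keys win on any collision. A quick sketch of that merge semantics (values illustrative):

quantization_override_env = {"OPTION_QUANTIZE": "awq", "OPTION_TENSOR_PARALLEL_DEGREE": "1"}
compilation_override_env = {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}

# Spreading the compilation env second lets it override quantization on shared keys.
merged = {
    **(quantization_override_env or {}),
    **(compilation_override_env or {}),
}
assert merged == {"OPTION_QUANTIZE": "awq", "OPTION_TENSOR_PARALLEL_DEGREE": "2"}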

src/sagemaker/serve/builder/model_builder.py

Lines changed: 33 additions & 7 deletions

@@ -1235,9 +1235,6 @@ def _model_builder_optimize_wrapper(
         if self.mode != Mode.SAGEMAKER_ENDPOINT:
             raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
 
-        if quantization_config and compilation_config:
-            raise ValueError("Quantization config and compilation config are mutually exclusive.")
-
         self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
         self.instance_type = instance_type or self.instance_type
         self.role_arn = role_arn or self.role_arn

@@ -1279,6 +1276,28 @@ def _model_builder_optimize_wrapper(
         )
 
         if input_args:
+            optimization_instance_type = input_args["DeploymentInstanceType"]
+
+            # Compilation using TRTLLM and Llama-3.1 is currently not supported.
+            # TRTLLM is used by Neo if the following are provided:
+            # 1) a GPU instance type
+            # 2) compilation config
+            gpu_instance_families = ["g4", "g5", "p4d"]
+            is_gpu_instance = optimization_instance_type and any(
+                gpu_instance_family in optimization_instance_type
+                for gpu_instance_family in gpu_instance_families
+            )
+
+            # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B"
+            # JS Model ID format = "meta-textgeneration-llama-3-1-8b"
+            llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
+            is_llama_3_1 = self.model and any(
+                keyword in self.model.lower() for keyword in llama_3_1_keywords
+            )
+
+            if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled:
+                raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")
+
             self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
             job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
             return _generate_optimized_model(self.pysdk_model, job_status)

@@ -1342,11 +1361,18 @@ def _optimize_for_hf(
         model_source = _generate_model_source(self.pysdk_model.model_data, False)
         create_optimization_job_args["ModelSource"] = model_source
 
-        optimization_config, override_env = _extract_optimization_config_and_env(
-            quantization_config, compilation_config
+        optimization_config, quantization_override_env, compilation_override_env = (
+            _extract_optimization_config_and_env(quantization_config, compilation_config)
+        )
+        create_optimization_job_args["OptimizationConfigs"] = [
+            {k: v} for k, v in optimization_config.items()
+        ]
+        self.pysdk_model.env.update(
+            {
+                **(quantization_override_env or {}),
+                **(compilation_override_env or {}),
+            }
         )
-        create_optimization_job_args["OptimizationConfigs"] = [optimization_config]
-        self.pysdk_model.env.update(override_env)
 
         output_config = {"S3OutputLocation": output_path}
         if kms_key:

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 20 additions & 10 deletions

@@ -260,7 +260,7 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
 
 def _extract_optimization_config_and_env(
     quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
-) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
+) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
     """Extracts optimization config and environment variables.
 
     Args:

@@ -271,15 +271,25 @@ def _extract_optimization_config_and_env(
         Optional[Tuple[Optional[Dict], Optional[Dict]]]:
             The optimization config and environment variables.
     """
-    if quantization_config:
-        return {"ModelQuantizationConfig": quantization_config}, quantization_config.get(
-            "OverrideEnvironment"
-        )
-    if compilation_config:
-        return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
-            "OverrideEnvironment"
-        )
-    return None, None
+    optimization_config = {}
+    quantization_override_env = (
+        quantization_config.get("OverrideEnvironment", {}) if quantization_config else None
+    )
+    compilation_override_env = (
+        compilation_config.get("OverrideEnvironment", {}) if compilation_config else None
+    )
+
+    if quantization_config is not None:
+        optimization_config["ModelQuantizationConfig"] = quantization_config
+
+    if compilation_config is not None:
+        optimization_config["ModelCompilationConfig"] = compilation_config
+
+    # Return both dicts and environment variable if either is present
+    if optimization_config:
+        return optimization_config, quantization_override_env, compilation_override_env
+
+    return None, None, None
 
 
 def _custom_speculative_decoding(
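Callers now unpack a three-tuple from _extract_optimization_config_and_env: the combined config dict plus one override environment per optimization type. A sketch of the new contract (inputs illustrative):

from sagemaker.serve.utils.optimize_utils import _extract_optimization_config_and_env

optimization_config, quantization_env, compilation_env = _extract_optimization_config_and_env(
    quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
    compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
)

# The combined config carries one key per requested optimization.
assert set(optimization_config) == {"ModelQuantizationConfig", "ModelCompilationConfig"}
assert quantization_env == {"OPTION_QUANTIZE": "awq"}
assert compilation_env == {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}

# With neither config supplied, all three slots are None.
assert _extract_optimization_config_and_env() == (None, None, None)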

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 159 additions & 0 deletions

@@ -1198,6 +1198,65 @@ def test_optimize_quantize_for_jumpstart(
 
         self.assertIsNotNone(out_put)
 
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_quantize_and_compile_for_jumpstart(
+        self,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_metadata_config = Mock()
+        mock_metadata_config.resolved_config = {
+            "supported_inference_instance_types": ["ml.inf2.48xlarge"],
+            "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b",
+        }
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
+        mock_pysdk_model.model_data = mock_model_data
+        mock_pysdk_model.image_uri = mock_tgi_image_uri
+        mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
+        mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
+        mock_pysdk_model.config_name = "config_name"
+        mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config}
+
+        sample_input = {
+            "inputs": "The diamondback terrapin or simply terrapin is a species "
+            "of turtle native to the brackish coastal tidal marshes of the",
+            "parameters": {"max_new_tokens": 1024},
+        }
+        sample_output = [
+            {
+                "generated_text": "The diamondback terrapin or simply terrapin is a "
+                "species of turtle native to the brackish coastal "
+                "tidal marshes of the east coast."
+            }
+        ]
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            schema_builder=SchemaBuilder(sample_input, sample_output),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        out_put = model_builder._optimize_for_jumpstart(
+            accept_eula=True,
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
+            env_vars={
+                "OPTION_TENSOR_PARALLEL_DEGREE": "1",
+                "OPTION_MAX_ROLLING_BATCH_SIZE": "2",
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertIsNotNone(out_put)
+
     @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
     @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
     @patch(

@@ -1383,3 +1442,103 @@ def test_optimize_compile_for_jumpstart_with_neuron_env(
         self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto")
         self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4")
         self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2")
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model")
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+        return_value=({"model_type": "t5", "n_head": 71}, True),
+    )
+    def test_optimize_compile_for_jumpstart_without_compilation_config(
+        self,
+        mock_prepare_for_tgi,
+        mock_pre_trained_model,
+        mock_is_jumpstart_model,
+        mock_js_model,
+        mock_is_gated_model,
+        mock_serve_settings,
+        mock_telemetry,
+    ):
+        mock_sagemaker_session = Mock()
+        mock_metadata_config = Mock()
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = (
+            lambda *args: mock_optimization_job_response
+        )
+
+        mock_metadata_config.resolved_config = {
+            "supported_inference_instance_types": ["ml.inf2.48xlarge"],
+            "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b",
+        }
+
+        mock_js_model.return_value = MagicMock()
+        mock_js_model.return_value.env = {
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "ENDPOINT_SERVER_TIMEOUT": "3600",
+            "MODEL_CACHE_ROOT": "/opt/ml/model",
+            "SAGEMAKER_ENV": "1",
+            "HF_MODEL_ID": "/opt/ml/model",
+            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+        }
+
+        mock_pre_trained_model.return_value = MagicMock()
+        mock_pre_trained_model.return_value.env = dict()
+        mock_pre_trained_model.return_value.config_name = "config_name"
+        mock_pre_trained_model.return_value.model_data = mock_model_data
+        mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+        mock_pre_trained_model.return_value.list_deployment_configs.return_value = (
+            DEPLOYMENT_CONFIGS
+        )
+        mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0]
+        mock_pre_trained_model.return_value._metadata_configs = {
+            "config_name": mock_metadata_config
+        }
+
+        sample_input = {
+            "inputs": "The diamondback terrapin or simply terrapin is a species "
+            "of turtle native to the brackish coastal tidal marshes of the",
+            "parameters": {"max_new_tokens": 1024},
+        }
+        sample_output = [
+            {
+                "generated_text": "The diamondback terrapin or simply terrapin is a "
+                "species of turtle native to the brackish coastal "
+                "tidal marshes of the east coast."
+            }
+        ]
+
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            schema_builder=SchemaBuilder(sample_input, sample_output),
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        optimized_model = model_builder.optimize(
+            accept_eula=True,
+            instance_type="ml.inf2.24xlarge",
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(
+            optimized_model.image_uri,
+            mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"],
+        )
+        self.assertEqual(
+            optimized_model.model_data["S3DataSource"]["S3Uri"],
+            mock_optimization_job_response["OutputConfig"]["S3OutputLocation"],
+        )
+        self.assertEqual(optimized_model.env["SAGEMAKER_PROGRAM"], "inference.py")
+        self.assertEqual(optimized_model.env["ENDPOINT_SERVER_TIMEOUT"], "3600")
+        self.assertEqual(optimized_model.env["MODEL_CACHE_ROOT"], "/opt/ml/model")
+        self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1")
+        self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model")
+        self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1")
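Both call sites expand the combined dict into the OptimizationConfigs request field one entry at a time, so a job that quantizes and compiles sends a two-element list. A sketch of the resulting shape (dict insertion order puts quantization first):

optimization_config = {
    "ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
    "ModelCompilationConfig": {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
}

# [{k: v} for k, v in optimization_config.items()] wraps each config separately.
optimization_configs = [{k: v} for k, v in optimization_config.items()]
assert optimization_configs == [
    {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}},
    {"ModelCompilationConfig": {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}},
]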
