
Commit f121eb0

Author: Joseph Zhang (committed)
Add TRTLLM compilation + speculative decoding validation.
1 parent bf706ad commit f121eb0

File tree

2 files changed (+66, -3 lines)

src/sagemaker/serve/builder/model_builder.py

Lines changed: 11 additions & 3 deletions
@@ -1283,7 +1283,7 @@ def _model_builder_optimize_wrapper(
         # TRTLLM is used by Neo if the following are provided:
         # 1) a GPU instance type
         # 2) compilation config
-        gpu_instance_families = ["g5", "g6", "p4d", "p5"]
+        gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"]
         is_gpu_instance = optimization_instance_type and any(
             gpu_instance_family in optimization_instance_type
             for gpu_instance_family in gpu_instance_families
@@ -1296,8 +1296,16 @@ def _model_builder_optimize_wrapper(
             keyword in self.model.lower() for keyword in llama_3_1_keywords
         )
 
-        if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled:
-            raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")
+        if is_gpu_instance and self.model and self.is_compiled:
+            if is_llama_3_1:
+                raise ValueError(
+                    "Compilation is not supported for Llama-3.1 with a GPU instance."
+                )
+            if speculative_decoding_config:
+                raise ValueError(
+                    "Compilation is not supported with speculative decoding with "
+                    "a GPU instance."
+                )
 
         self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
         job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
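To make the new control flow concrete, here is a minimal standalone sketch of the guard with hypothetical inputs. The real code reads self.* state inside _model_builder_optimize_wrapper, and the llama_3_1_keywords values are not shown in this hunk, so the keyword list below is an assumption.

    # Standalone sketch of the updated guard (hypothetical inputs).
    gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"]

    optimization_instance_type = "ml.g5.2xlarge"  # hypothetical
    model = "modelid"                             # hypothetical, not a Llama-3.1 model
    is_compiled = True                            # a compilation_config was passed
    speculative_decoding_config = {"ModelProvider": "custom"}  # hypothetical

    is_gpu_instance = optimization_instance_type and any(
        family in optimization_instance_type for family in gpu_instance_families
    )
    llama_3_1_keywords = ["llama-3.1", "llama-3-1"]  # assumed; not shown in this hunk
    is_llama_3_1 = any(keyword in model.lower() for keyword in llama_3_1_keywords)

    if is_gpu_instance and model and is_compiled:
        if is_llama_3_1:
            raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")
        if speculative_decoding_config:
            # Running this sketch as-is raises here, mirroring the new behavior.
            raise ValueError(
                "Compilation is not supported with speculative decoding with a GPU instance."
            )

Note that both checks sit under a single compilation gate, so the Llama-3.1 error takes precedence when both conditions hold.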

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 55 additions & 0 deletions
@@ -2891,3 +2891,58 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
                 output_path="s3://bucket/code/",
             ),
         )
+
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "modelid"}
+
+        sample_input = {"inputs": "dummy prompt", "parameters": {}}
+
+        sample_output = [{"generated_text": "dummy response"}]
+
+        dummy_schema_builder = SchemaBuilder(sample_input, sample_output)
+
+        model_builder = ModelBuilder(
+            model="modelid",
+            schema_builder=dummy_schema_builder,
+            env_vars={"HF_TOKEN": "token"},
+            model_metadata={
+                "CUSTOM_MODEL_PATH": "s3://bucket/path/",
+            },
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Compilation is not supported with speculative decoding with a GPU instance.",
+            lambda: model_builder.optimize(
+                job_name="job_name-123",
+                speculative_decoding_config={
+                    "ModelProvider": "custom",
+                    "ModelSource": "s3://data-source",
+                },
+                compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
+                output_path="s3://bucket/code/",
+            ),
+        )
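For context, a hedged sketch of how a caller would hit this validation outside the test harness. The argument values mirror the unit test above and are placeholders ("modelid", the role ARN, and the S3 URIs are illustrative, not a recommended configuration).

    from sagemaker.serve import ModelBuilder, SchemaBuilder

    # ml.g5.2xlarge falls in one of the guarded GPU instance families.
    builder = ModelBuilder(
        model="modelid",
        schema_builder=SchemaBuilder(
            sample_input={"inputs": "dummy prompt", "parameters": {}},
            sample_output=[{"generated_text": "dummy response"}],
        ),
        role_arn="role-arn",
        instance_type="ml.g5.2xlarge",
    )

    try:
        builder.optimize(
            job_name="job_name-123",
            speculative_decoding_config={
                "ModelProvider": "custom",
                "ModelSource": "s3://data-source",
            },
            compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
            output_path="s3://bucket/code/",
        )
    except ValueError as err:
        # With both configs on a GPU instance, optimize() now fails fast
        # before any optimization job is created.
        print(err)

The validation runs client-side, so the incompatible combination is rejected before create_optimization_job is ever called.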
