
Commit 828ad60

Joseph Zhang committed
Add validations to block compilation jobs using TRTLLM and Llama-3.1.
1 parent 54e995f · commit 828ad60

2 files changed: +73 −0 lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 22 additions & 0 deletions
@@ -1276,6 +1276,28 @@ def _model_builder_optimize_wrapper(
             )
 
         if input_args:
+            optimization_instance_type = input_args["DeploymentInstanceType"]
+
+            # Compilation using TRTLLM and Llama-3.1 is currently not supported.
+            # TRTLLM is used by Neo if the following are provided:
+            # 1) a GPU instance type
+            # 2) compilation config
+            gpu_instance_families = ["g4", "g5", "p4d"]
+            is_gpu_instance = optimization_instance_type and any(
+                gpu_instance_family in optimization_instance_type
+                for gpu_instance_family in gpu_instance_families
+            )
+
+            # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B"
+            # JS Model ID format = "meta-textgeneration-llama-3-1-8b"
+            llama_3_1_keywords = ["llama-3.1", "llama-3-1"]
+            is_llama_3_1 = self.model and any(
+                keyword in self.model.lower() for keyword in llama_3_1_keywords
+            )
+
+            if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled:
+                raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")
+
             self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
             job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
             return _generate_optimized_model(self.pysdk_model, job_status)
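
To make the new guard easier to follow outside the diff, here is a minimal standalone sketch of the same substring checks. The instance type and model ID below are hypothetical inputs, and in the SDK the check additionally requires self.is_compiled before raising:

# Minimal sketch of the guard added above (not the SDK implementation).
# Both inputs are hypothetical; the real check also requires is_compiled.
optimization_instance_type = "ml.g5.2xlarge"    # hypothetical GPU instance type
model = "meta-llama/Meta-Llama-3.1-8B"          # hypothetical HF model ID

gpu_instance_families = ["g4", "g5", "p4d"]
is_gpu_instance = any(fam in optimization_instance_type for fam in gpu_instance_families)

llama_3_1_keywords = ["llama-3.1", "llama-3-1"]  # HF and JumpStart ID formats
is_llama_3_1 = any(kw in model.lower() for kw in llama_3_1_keywords)

if is_gpu_instance and is_llama_3_1:
    raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.")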

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 51 additions & 0 deletions
@@ -2840,3 +2840,54 @@ def test_optimize_for_hf_without_custom_s3_path(
                 "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
             },
         )
+
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"}
+
+        sample_input = {"inputs": "dummy prompt", "parameters": {}}
+
+        sample_output = [{"generated_text": "dummy response"}]
+
+        dummy_schema_builder = SchemaBuilder(sample_input, sample_output)
+
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-1-8B-Instruct",
+            schema_builder=dummy_schema_builder,
+            env_vars={"HF_TOKEN": "token"},
+            model_metadata={
+                "CUSTOM_MODEL_PATH": "s3://bucket/path/",
+            },
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Compilation is not supported for Llama-3.1 with a GPU instance.",
+            lambda: model_builder.optimize(
+                job_name="job_name-123",
+                compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
+                output_path="s3://bucket/code/",
+            ),
+        )
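
From the caller's side, the scenario this test exercises looks roughly like the sketch below. The model ID, instance type, and job arguments are illustrative (taken from the test above), and optimize() raises only when a compilation config routes the job to TRTLLM on a GPU instance:

# Illustrative only: a Llama-3.1 model ID plus a "g5" instance and a
# compilation config trips the new validation before any job is created.
builder = ModelBuilder(
    model="meta-llama/Meta-Llama-3-1-8B-Instruct",  # matches keyword "llama-3-1"
    instance_type="ml.g5.2xlarge",                  # matches GPU family "g5"
    role_arn="role-arn",
)
builder.optimize(
    job_name="job_name-123",
    compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
    output_path="s3://bucket/code/",
)
# -> ValueError: Compilation is not supported for Llama-3.1 with a GPU instance.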
