diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index e5e850b885..a7a518105c 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1433,15 +1433,15 @@ def _model_builder_optimize_wrapper( # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B" # JS Model ID format = "meta-textgeneration-llama-3-1-8b" - llama_3_1_keywords = ["llama-3.1", "llama-3-1"] - is_llama_3_1 = self.model and any( - keyword in self.model.lower() for keyword in llama_3_1_keywords + is_llama_3_plus = self.model and bool( + re.search(r"llama-3[\.\-][1-9]\d*", self.model.lower()) ) if is_gpu_instance and self.model and self.is_compiled: - if is_llama_3_1: + if is_llama_3_plus: raise ValueError( - "Compilation is not supported for Llama-3.1 with a GPU instance." + "Compilation is not supported for models greater " + "than Llama-3.0 with a GPU instance." ) if speculative_decoding_config: raise ValueError( diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 7355fe4f38..1e20bf1cf3 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -3270,7 +3270,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( mock_pysdk_model = Mock() mock_pysdk_model.model_data = None - mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"} + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-2-8B-Instruct"} sample_input = {"inputs": "dummy prompt", "parameters": {}} @@ -3279,7 +3279,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( dummy_schema_builder = SchemaBuilder(sample_input, sample_output) model_builder = ModelBuilder( - model="meta-llama/Meta-Llama-3-1-8B-Instruct", + model="meta-llama/Meta-Llama-3-2-8B-Instruct", schema_builder=dummy_schema_builder, env_vars={"HF_TOKEN": "token"}, model_metadata={ @@ -3293,7 +3293,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( self.assertRaisesRegex( ValueError, - "Compilation is not supported for Llama-3.1 with a GPU instance.", + "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.", lambda: model_builder.optimize( job_name="job_name-123", instance_type="ml.g5.24xlarge",