From 7ec16e6df03865831a1e484511a0efb7a5ada88b Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Tue, 17 Sep 2024 17:36:54 -0700 Subject: [PATCH 01/25] Enable quantization and compilation in the same optimization job via ModelBuilder and add validations to block compilation jobs using TRTLLM and Llama-3.1. --- .../serve/builder/jumpstart_builder.py | 44 +++-- src/sagemaker/serve/builder/model_builder.py | 40 ++++- src/sagemaker/serve/utils/optimize_utils.py | 30 ++-- .../serve/builder/test_js_builder.py | 159 ++++++++++++++++++ .../serve/builder/test_model_builder.py | 127 ++++++++++++-- .../serve/utils/test_optimize_utils.py | 15 +- 6 files changed, 370 insertions(+), 45 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index cfb43b813a..54974fd9f4 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -718,24 +718,36 @@ def _optimize_for_jumpstart( f"Model '{self.model}' requires accepting end-user license agreement (EULA)." 
) - is_compilation = (not quantization_config) and ( - (compilation_config is not None) or _is_inferentia_or_trainium(instance_type) + is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium( + instance_type ) pysdk_model_env_vars = dict() if is_compilation: pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) - optimization_config, override_env = _extract_optimization_config_and_env( - quantization_config, compilation_config + # optimization_config can contain configs for both quantization and compilation + optimization_config, quantization_override_env, compilation_override_env = ( + _extract_optimization_config_and_env(quantization_config, compilation_config) ) - if not optimization_config and is_compilation: - override_env = override_env or pysdk_model_env_vars - optimization_config = { - "ModelCompilationConfig": { - "OverrideEnvironment": override_env, - } - } + if ( + not optimization_config or not optimization_config.get("ModelCompilationConfig") + ) and is_compilation: + # Ensure optimization_config exists + if not optimization_config: + optimization_config = {} + + # Fallback to default if override_env is None or empty + if not compilation_override_env: + compilation_override_env = pysdk_model_env_vars + + # Update optimization_config with ModelCompilationConfig + override_compilation_config = ( + {"OverrideEnvironment": compilation_override_env} + if compilation_override_env + else {} + ) + optimization_config["ModelCompilationConfig"] = override_compilation_config if speculative_decoding_config: self._set_additional_model_source(speculative_decoding_config) @@ -766,7 +778,7 @@ def _optimize_for_jumpstart( "OptimizationJobName": job_name, "ModelSource": model_source, "DeploymentInstanceType": self.instance_type, - "OptimizationConfigs": [optimization_config], + "OptimizationConfigs": [{k: v} for k, v in optimization_config.items()], "OutputConfig": output_config, "RoleArn": self.role_arn, } @@ -789,7 +801,13 @@ 
def _optimize_for_jumpstart( "AcceptEula": True } - optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env) + optimization_env_vars = _update_environment_variables( + optimization_env_vars, + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + }, + ) if optimization_env_vars: self.pysdk_model.env.update(optimization_env_vars) if quantization_config or is_compilation: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index d1f1ab6ba2..70f8466695 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1235,9 +1235,6 @@ def _model_builder_optimize_wrapper( if self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") - if quantization_config and compilation_config: - raise ValueError("Quantization config and compilation config are mutually exclusive.") - self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.instance_type = instance_type or self.instance_type self.role_arn = role_arn or self.role_arn @@ -1279,6 +1276,28 @@ def _model_builder_optimize_wrapper( ) if input_args: + optimization_instance_type = input_args["DeploymentInstanceType"] + + # Compilation using TRTLLM and Llama-3.1 is currently not supported. 
+ # TRTLLM is used by Neo if the following are provided: + # 1) a GPU instance type + # 2) compilation config + gpu_instance_families = ["g4", "g5", "p4d"] + is_gpu_instance = optimization_instance_type and any( + gpu_instance_family in optimization_instance_type + for gpu_instance_family in gpu_instance_families + ) + + # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B" + # JS Model ID format = "meta-textgeneration-llama-3-1-8b" + llama_3_1_keywords = ["llama-3.1", "llama-3-1"] + is_llama_3_1 = self.model and any( + keyword in self.model.lower() for keyword in llama_3_1_keywords + ) + + if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled: + raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.") + self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) job_status = self.sagemaker_session.wait_for_optimization_job(job_name) return _generate_optimized_model(self.pysdk_model, job_status) @@ -1342,11 +1361,18 @@ def _optimize_for_hf( model_source = _generate_model_source(self.pysdk_model.model_data, False) create_optimization_job_args["ModelSource"] = model_source - optimization_config, override_env = _extract_optimization_config_and_env( - quantization_config, compilation_config + optimization_config, quantization_override_env, compilation_override_env = ( + _extract_optimization_config_and_env(quantization_config, compilation_config) + ) + create_optimization_job_args["OptimizationConfigs"] = [ + {k: v} for k, v in optimization_config.items() + ] + self.pysdk_model.env.update( + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + } ) - create_optimization_job_args["OptimizationConfigs"] = [optimization_config] - self.pysdk_model.env.update(override_env) output_config = {"S3OutputLocation": output_path} if kms_key: diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 5781c0bade..d0c6314af3 100644 --- 
a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -260,7 +260,7 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool: def _extract_optimization_config_and_env( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None -) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]: +) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: """Extracts optimization config and environment variables. Args: @@ -271,15 +271,25 @@ def _extract_optimization_config_and_env( Optional[Tuple[Optional[Dict], Optional[Dict]]]: The optimization config and environment variables. """ - if quantization_config: - return {"ModelQuantizationConfig": quantization_config}, quantization_config.get( - "OverrideEnvironment" - ) - if compilation_config: - return {"ModelCompilationConfig": compilation_config}, compilation_config.get( - "OverrideEnvironment" - ) - return None, None + optimization_config = {} + quantization_override_env = ( + quantization_config.get("OverrideEnvironment", {}) if quantization_config else None + ) + compilation_override_env = ( + compilation_config.get("OverrideEnvironment", {}) if compilation_config else None + ) + + if quantization_config is not None: + optimization_config["ModelQuantizationConfig"] = quantization_config + + if compilation_config is not None: + optimization_config["ModelCompilationConfig"] = compilation_config + + # Return both dicts and environment variable if either is present + if optimization_config: + return optimization_config, quantization_override_env, compilation_override_env + + return None, None, None def _custom_speculative_decoding( diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 248955c273..e1eeb2e921 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -1198,6 +1198,65 @@ def 
test_optimize_quantize_for_jumpstart( self.assertIsNotNone(out_put) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_and_compile_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.config_name = "config_name" + mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + env_vars={ + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + "OPTION_MAX_ROLLING_BATCH_SIZE": "2", + }, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) @patch( @@ -1383,3 +1442,103 @@ def test_optimize_compile_for_jumpstart_with_neuron_env( self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto") self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4") self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2") + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_compilation_config( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + 
mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + } + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.24xlarge", + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["SAGEMAKER_PROGRAM"], "inference.py") + self.assertEqual(optimized_model.env["ENDPOINT_SERVER_TIMEOUT"], "3600") + self.assertEqual(optimized_model.env["MODEL_CACHE_ROOT"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1") + self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1") diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index b50aa17c34..fffb548245 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -2650,21 +2650,75 @@ def test_optimize_local_mode(self, mock_get_serve_setting): ), ) + @patch.object(ModelBuilder, "_prepare_for_mode") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) - def test_optimize_exclusive_args(self, mock_get_serve_setting): - mock_sagemaker_session = Mock() + def test_optimize_for_hf_with_both_quantization_and_compilation( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": 
"bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruc"} + model_builder = ModelBuilder( - model="meta-textgeneration-llama-3-70b", - sagemaker_session=mock_sagemaker_session, + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", ) - self.assertRaisesRegex( - ValueError, - "Quantization config and compilation config are mutually exclusive.", - lambda: model_builder.optimize( - quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, - compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, - ), + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_hf( + job_name="job_name-123", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ) + + self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token") + self.assertEqual(model_builder.role_arn, "role-arn") + self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge") + self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq") + self.assertEqual(model_builder.pysdk_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2") + self.assertEqual( + out_put, + { + "OptimizationJobName": "job_name-123", + "DeploymentInstanceType": "ml.g5.2xlarge", + "RoleArn": "role-arn", + "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}}, + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"} + } + }, + { + "ModelCompilationConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"} + } + }, + ], + "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, + }, ) 
@patch.object(ModelBuilder, "_prepare_for_mode") @@ -2786,3 +2840,54 @@ def test_optimize_for_hf_without_custom_s3_path( "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"}, }, ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-1-8B-Instruct"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-1-8B-Instruct", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Compilation is not supported for Llama-3.1 with a GPU instance.", + lambda: model_builder.optimize( + job_name="job_name-123", + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index a8dc6d74f4..95e3d82fef 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected): 
@pytest.mark.parametrize( - "quantization_config, compilation_config, expected_config, expected_env", + "quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env", [ ( None, @@ -277,6 +277,7 @@ def test_is_s3_uri(s3_uri, expected): } }, }, + None, { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, @@ -298,16 +299,22 @@ def test_is_s3_uri(s3_uri, expected): { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, + None, ), - (None, None, None, None), + (None, None, None, None, None), ], ) def test_extract_optimization_config_and_env( - quantization_config, compilation_config, expected_config, expected_env + quantization_config, + compilation_config, + expected_config, + expected_quant_env, + expected_compilation_env, ): assert _extract_optimization_config_and_env(quantization_config, compilation_config) == ( expected_config, - expected_env, + expected_quant_env, + expected_compilation_env, ) From cf70f596bf9e2a7976cbfc5716581ad64d5f141f Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Thu, 7 Nov 2024 18:12:42 -0800 Subject: [PATCH 02/25] Require EULA acceptance when using a gated 1p draft model via ModelBuilder. 
--- .../serve/builder/jumpstart_builder.py | 10 +- src/sagemaker/serve/builder/model_builder.py | 16 +++ src/sagemaker/serve/utils/optimize_utils.py | 133 +++++++++++++++++- .../serve/builder/test_js_builder.py | 54 ++++++- tests/unit/sagemaker/serve/constants.py | 102 ++++++++++++++ .../serve/utils/test_optimize_utils.py | 16 +++ 6 files changed, 321 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 54974fd9f4..0e09e351c0 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -48,6 +48,7 @@ _custom_speculative_decoding, SPECULATIVE_DRAFT_MODEL, _is_inferentia_or_trainium, + _validate_and_set_eula_for_draft_model_sources, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -733,10 +734,6 @@ def _optimize_for_jumpstart( if ( not optimization_config or not optimization_config.get("ModelCompilationConfig") ) and is_compilation: - # Ensure optimization_config exists - if not optimization_config: - optimization_config = {} - # Fallback to default if override_env is None or empty if not compilation_override_env: compilation_override_env = pysdk_model_env_vars @@ -867,6 +864,11 @@ def _set_additional_model_source( "Cannot find deployment config compatible for optimization job." 
) + _validate_and_set_eula_for_draft_model_sources( + pysdk_model=self.pysdk_model, + accept_eula=speculative_decoding_config.get("AcceptEula"), + ) + self.pysdk_model.env.update( {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"} ) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 70f8466695..bc8bef1626 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -99,6 +99,7 @@ validate_image_uri_and_hardware, ) from sagemaker.utils import Tags +from sagemaker.serve.utils.optimize_utils import _validate_and_set_eula_for_draft_model_sources from sagemaker.workflow.entities import PipelineVariable from sagemaker.huggingface.llm_utils import ( get_huggingface_model_metadata, @@ -589,6 +590,21 @@ def _model_builder_deploy_wrapper( model_server=self.model_server, ) + if self.deployment_config: + accept_draft_model_eula = kwargs.get("accept_draft_model_eula", False) + try: + _validate_and_set_eula_for_draft_model_sources( + pysdk_model=self, + accept_eula=accept_draft_model_eula, + ) + except ValueError as e: + logger.error( + "This deployment tried to use a gated draft model but the EULA was not " + "accepted. Please review the EULA, set accept_draft_model_eula to True, " + "and try again." 
+ ) + raise e + if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True predictor = self._original_deploy( diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index d0c6314af3..53fe1a87ab 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -172,6 +172,60 @@ def _extract_speculative_draft_model_provider( return "sagemaker" +def _extract_additional_model_data_source_s3_uri( + additional_model_data_source: Optional[Dict] = None, +) -> Optional[str]: + """Extracts model data source s3 uri from a model data source in Pascal case. + + Args: + additional_model_data_source (Optional[Dict]): A model data source. + + Returns: + str: S3 uri of the model resources. + """ + if ( + additional_model_data_source is None + or additional_model_data_source.get("S3DataSource", None) is None + ): + return None + + return additional_model_data_source.get("S3DataSource").get("S3Uri", None) + + +def _extract_deployment_config_additional_model_data_source_s3_uri( + additional_model_data_source: Optional[Dict] = None, +) -> Optional[str]: + """Extracts model data source s3 uri from a model data source in snake case. + + Args: + additional_model_data_source (Optional[Dict]): A model data source. + + Returns: + str: S3 uri of the model resources. + """ + if ( + additional_model_data_source is None + or additional_model_data_source.get("s3_data_source", None) is None + ): + return None + + return additional_model_data_source.get("s3_data_source").get("s3_uri", None) + + +def _is_draft_model_gated( + draft_model_config: Optional[Dict] = None, +) -> bool: + """Extracts model gated-ness from draft model data source. + + Args: + draft_model_config (Optional[Dict]): A model data source. + + Returns: + bool: Whether the draft model is gated or not. 
+ """ + return draft_model_config.get("hosting_eula_key", None) + + def _extracts_and_validates_speculative_model_source( speculative_decoding_config: Dict, ) -> str: @@ -289,7 +343,7 @@ def _extract_optimization_config_and_env( if optimization_config: return optimization_config, quantization_override_env, compilation_override_env - return None, None, None + return {}, None, None def _custom_speculative_decoding( @@ -310,6 +364,8 @@ def _custom_speculative_decoding( speculative_decoding_config ) + accept_eula = speculative_decoding_config.get("AcceptEula", False) + if _is_s3_uri(additional_model_source): channel_name = _generate_channel_name(model.additional_model_data_sources) speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}" @@ -326,3 +382,78 @@ def _custom_speculative_decoding( ) return model + + +def _validate_and_set_eula_for_draft_model_sources( + pysdk_model: Model, + accept_eula: bool = False, +): + """Validates whether the EULA has been accepted for gated additional draft model sources. + + If accepted, updates the model data source's model access config. + + Args: + pysdk_model (Model): The model whose additional model data sources to check. + accept_eula (bool): EULA acceptance for the draft model. + """ + if not pysdk_model: + return + + deployment_config_draft_model_sources = ( + pysdk_model.deployment_config.get("DeploymentArgs", {}) + .get("AdditionalDataSources", {}) + .get("speculative_decoding", []) + if pysdk_model.deployment_config + else None + ) + pysdk_model_additional_model_sources = pysdk_model.additional_model_data_sources + + if not deployment_config_draft_model_sources or not pysdk_model_additional_model_sources: + return + + # Gated/ungated classification is only available through deployment_config. + # Thus we must check each draft model in the deployment_config and see if it is set + # as an additional model data source on the PySDK model itself. 
+ model_access_config_updated = False + for source in deployment_config_draft_model_sources: + if source.get("channel_name") != "draft_model": + continue + + if not _is_draft_model_gated(source): + continue + + deployment_config_draft_model_source_s3_uri = ( + _extract_deployment_config_additional_model_data_source_s3_uri(source) + ) + + # If EULA is accepted, proceed with modifying the draft model data source + for additional_source in pysdk_model_additional_model_sources: + if additional_source.get("ChannelName") != "draft_model": + continue + + # Verify the pysdk model source and deployment config model source match + pysdk_model_source_s3_uri = _extract_additional_model_data_source_s3_uri( + additional_source + ) + if deployment_config_draft_model_source_s3_uri not in pysdk_model_source_s3_uri: + continue + + if not accept_eula: + raise ValueError( + "Gated draft model requires accepting end-user license agreement (EULA)." + ) + + # Set ModelAccessConfig.AcceptEula to True + updated_source = additional_source.copy() + updated_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} + + index = pysdk_model.additional_model_data_sources.index(additional_source) + pysdk_model.additional_model_data_sources[index] = updated_source + + model_access_config_updated = True + break + + if model_access_config_updated: + break + + return diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index e1eeb2e921..265907db45 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -25,7 +25,10 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) -from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS +from tests.unit.sagemaker.serve.constants import ( + DEPLOYMENT_CONFIGS, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, +) mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = 
"google/flan-t5-xxl" @@ -1198,6 +1201,51 @@ def test_optimize_quantize_for_jumpstart( self.assertIsNotNone(out_put) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaises( + ValueError, + model_builder._optimize_for_jumpstart( + accept_eula=True, + speculative_decoding_config={"Provider": "sagemaker", "AcceptEula": False}, + ), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) def test_optimize_quantize_and_compile_for_jumpstart( @@ -1248,10 +1296,6 @@ def test_optimize_quantize_and_compile_for_jumpstart( "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, }, compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, - env_vars={ - "OPTION_TENSOR_PARALLEL_DEGREE": "1", - "OPTION_MAX_ROLLING_BATCH_SIZE": "2", - }, output_path="s3://bucket/code/", ) diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index 5a4679747b..c473750411 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -165,3 +165,105 @@ }, }, ] +OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = { + "DeploymentConfigName": "lmi-optimized", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.29.0-lmi11.0.0-cu124", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-1-70b/artifacts/inference-prepack/v2.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "ModelPackageArn": None, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_SPECULATIVE_DRAFT_MODEL": 
"/opt/ml/additional-model-data-sources/draft_model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.g6.2xlarge", + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 131072, + "NumberOfAcceleratorDevicesRequired": 1, + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "AdditionalDataSources": { + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, + } + ] + }, + }, + "AccelerationConfigs": [ + { + "type": "Compilation", + "enabled": False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + { + "type": "Speculative-Decoding", + "enabled": True, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "LMI v11 does not support Speculative Decoding for TRT", + } + }, + }, + { + "type": "Quantization", + "enabled": False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + ], + "BenchmarkMetrics": {"ml.g6.2xlarge": None}, +} +GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, +} +NON_GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "s3_data_source": { + 
"compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://sagemaker-sd-models-beta-us-west-2/" + "sagemaker-speculative-decoding-llama3-small-v3/", + }, +} diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 95e3d82fef..8b11b40060 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -31,6 +31,11 @@ _is_optimized, _custom_speculative_decoding, _is_inferentia_or_trainium, + _is_draft_model_gated, +) +from tests.unit.sagemaker.serve.constants import ( + GATED_DRAFT_MODEL_CONFIG, + NON_GATED_DRAFT_MODEL_CONFIG, ) mock_optimization_job_output = { @@ -260,6 +265,17 @@ def test_is_s3_uri(s3_uri, expected): assert _is_s3_uri(s3_uri) == expected +@pytest.mark.parametrize( + "draft_model_config, expected", + [ + (GATED_DRAFT_MODEL_CONFIG, NON_GATED_DRAFT_MODEL_CONFIG), + (True, False), + ], +) +def test_is_draft_model_gated(draft_model_config, expected): + assert _is_draft_model_gated(draft_model_config, expected) + + @pytest.mark.parametrize( "quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env", [ From fcb5092165112dee4f4c926bf909aecdde8dce1f Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Fri, 8 Nov 2024 02:52:24 +0000 Subject: [PATCH 03/25] add accept_draft_model_eula to JumpStartModel when deployment config with gated draft model is selected --- src/sagemaker/jumpstart/factory/model.py | 48 ++++++++++++++++++- src/sagemaker/jumpstart/model.py | 39 ++++++++++----- src/sagemaker/jumpstart/types.py | 7 ++- .../serve/builder/jumpstart_builder.py | 8 +++- 4 files changed, 86 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a5e9e1b6a4..99cdaf6cab 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -58,6 +58,7 
@@ update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + get_jumpstart_content_bucket, ) from sagemaker.jumpstart.factory.utils import ( @@ -70,7 +71,13 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session -from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags +from sagemaker.utils import ( + camel_case_to_pascal_case, + name_from_base, + format_tags, + Tags, + get_domain_for_region, +) from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker import resource_requirements @@ -556,6 +563,37 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta return kwargs +def _apply_accept_eula_on_model_data_source( + model_data_source: Dict[str, Any], + model_id: str, + region: str, + accept_eula: bool +): + """Sets AcceptEula to True for gated speculative decoding models""" + + mutable_model_data_source = model_data_source.copy() + + hosting_eula_key = mutable_model_data_source.get("hosting_eula_key") + del mutable_model_data_source["hosting_eula_key"] + + if not hosting_eula_key: + return mutable_model_data_source + + if not accept_eula: + raise ValueError( + ( + f"The set deployment config comes optimized with an additional model data source " + f"'{model_id}' that requires accepting end-user license agreement (EULA). " + f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." + f"{get_domain_for_region(region)}" + f"/{hosting_eula_key} for terms of use. Please set `accept_eula=True` once acknowledged." 
+ ) + ) + + mutable_model_data_source["model_access_config"] = {"accept_eula": accept_eula} + return mutable_model_data_source + + def _add_additional_model_data_sources_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: @@ -568,7 +606,11 @@ def _add_additional_model_data_sources_to_kwargs( data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region)) api_shape_additional_model_data_sources = ( [ - camel_case_to_pascal_case(data_source.to_json()) + camel_case_to_pascal_case( + _apply_accept_eula_on_model_data_source( + data_source.to_json(), kwargs.model_id, kwargs.region, kwargs.accept_draft_model_eula, + ) + ) for data_source in speculative_decoding_data_sources ] if specs.get_speculative_decoding_s3_data_sources() @@ -858,6 +900,7 @@ def get_init_kwargs( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, + accept_draft_model_eula: Optional[bool] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -892,6 +935,7 @@ def get_init_kwargs( resources=resources, config_name=config_name, additional_model_data_sources=additional_model_data_sources, + accept_draft_model_eula=accept_draft_model_eula, ) model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( kwargs=model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 486079718b..a42333320c 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -111,6 +111,7 @@ def __init__( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, + accept_draft_model_eula: Optional[bool] = None, ): """Initializes a ``JumpStartModel``. @@ -301,6 +302,10 @@ def __init__( optionally applied to the model. 
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location of SageMaker model data (default: None). + accept_draft_model_eula (bool): For draft models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_draft_model_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some Raises: ValueError: If the model ID is not recognized by JumpStart. """ @@ -360,6 +365,7 @@ def _validate_model_id_and_type(): resources=resources, config_name=config_name, additional_model_data_sources=additional_model_data_sources, + accept_draft_model_eula=accept_draft_model_eula ) self.orig_predictor_cls = predictor_cls @@ -456,7 +462,9 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) - def set_deployment_config(self, config_name: str, instance_type: str) -> None: + def set_deployment_config( + self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False + ) -> None: """Sets the deployment config to apply to the model. Args: @@ -466,6 +474,8 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: instance_type (str): The instance_type that the model will use after setting the config. + accept_draft_model_eula (Optional[bool]): + If the config selected comes with a gated additional model data source. 
""" self.__init__( model_id=self.model_id, @@ -474,6 +484,7 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: config_name=config_name, sagemaker_session=self.sagemaker_session, role=self.role, + accept_draft_model_eula=accept_draft_model_eula, ) @property @@ -540,12 +551,16 @@ def attach( inferred_model_id = inferred_model_version = inferred_inference_component_name = None if inference_component_name is None or model_id is None or model_version is None: - inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = ( - get_model_info_from_endpoint( - endpoint_name=endpoint_name, - inference_component_name=inference_component_name, - sagemaker_session=sagemaker_session, - ) + ( + inferred_model_id, + inferred_model_version, + inferred_inference_component_name, + _, + _, + ) = get_model_info_from_endpoint( + endpoint_name=endpoint_name, + inference_component_name=inference_component_name, + sagemaker_session=sagemaker_session, ) model_id = model_id or inferred_model_id @@ -1016,10 +1031,11 @@ def _get_deployment_configs( ) if metadata_config.benchmark_metrics: - err, metadata_config.benchmark_metrics = ( - add_instance_rate_stats_to_benchmark_metrics( - self.region, metadata_config.benchmark_metrics - ) + ( + err, + metadata_config.benchmark_metrics, + ) = add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics ) config_components = metadata_config.config_components.get(config_name) @@ -1042,6 +1058,7 @@ def _get_deployment_configs( region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, + accept_draft_model_eula=True, ) deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e77c407372..f4fa587c9d 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1083,7 +1083,7 @@ class AdditionalModelDataSource(JumpStartDataHolderType): 
SERIALIZATION_EXCLUSION_SET: Set[str] = set() - __slots__ = ["channel_name", "s3_data_source"] + __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"] def __init__(self, spec: Dict[str, Any]): """Initializes a AdditionalModelDataSource object. @@ -1101,6 +1101,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: """ self.channel_name: str = json_obj["channel_name"] self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) + self.hosting_eula_key: str = json_obj.get("hosting_eula_key") def to_json(self, exclude_keys=True) -> Dict[str, Any]: """Returns json representation of AdditionalModelDataSource object.""" @@ -2116,6 +2117,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "hub_content_type", "model_reference_arn", "specs", + "accept_draft_model_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -2131,6 +2133,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "training_instance_type", "config_name", "hub_content_type", + "accept_draft_model_eula", } def __init__( @@ -2165,6 +2168,7 @@ def __init__( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, + accept_draft_model_eula: Optional[bool] = False ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -2198,6 +2202,7 @@ def __init__( self.resources = resources self.config_name = config_name self.additional_model_data_sources = additional_model_data_sources + self.accept_draft_model_eula = accept_draft_model_eula class JumpStartModelDeployKwargs(JumpStartKwargs): diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 0e09e351c0..8593f1913b 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -502,7 +502,9 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, 
max_tuning_duration=max_tuning_duration ) - def set_deployment_config(self, config_name: str, instance_type: str) -> None: + def set_deployment_config( + self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False + ) -> None: """Sets the deployment config to apply to the model. Args: @@ -512,11 +514,13 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: instance_type (str): The instance_type that the model will use after setting the config. + accept_draft_model_eula (Optional[bool]): + If the config selected comes with a gated additional model data source. """ if not hasattr(self, "pysdk_model") or self.pysdk_model is None: raise Exception("Cannot set deployment config to an uninitialized model.") - self.pysdk_model.set_deployment_config(config_name, instance_type) + self.pysdk_model.set_deployment_config(config_name, instance_type, accept_draft_model_eula) self.deployment_config_name = config_name self.instance_type = instance_type From 9489b8d9440c7765882bb621f7a564598ffe039c Mon Sep 17 00:00:00 2001 From: EC2 Default User Date: Fri, 8 Nov 2024 04:30:59 +0000 Subject: [PATCH 04/25] add map of valid optimization combinations --- .../check_optimization_configurations.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/sagemaker/serve/validations/check_optimization_configurations.py diff --git a/src/sagemaker/serve/validations/check_optimization_configurations.py b/src/sagemaker/serve/validations/check_optimization_configurations.py new file mode 100644 index 0000000000..82477d02a1 --- /dev/null +++ b/src/sagemaker/serve/validations/check_optimization_configurations.py @@ -0,0 +1,38 @@ +SUPPORTED_OPTIMIZATION_CONFIGURATIONS = { + "trt": { + "supported_instance_families": ["p4d", "p4de", "p5", "g5", "g6"], + "compilation": True, + "quantization": { + "awq": True, + "fp8": True, + "gptq": False, + "smooth_quant": True + }, + "speculative_decoding": False, + "sharding": False 
+ }, + "vllm": { + "supported_instance_families": ["p4d", "p4de", "p5", "g5", "g6"], + "compilation": False, + "quantization": { + "awq": True, + "fp8": True, + "gptq": False, + "smooth_quant": False + }, + "speculative_decoding": True, + "sharding": True + }, + "neuron": { + "supported_instance_families": ["inf2", "trn1", "trn1n"], + "compilation": True, + "quantization": { + "awq": False, + "fp8": False, + "gptq": False, + "smooth_quant": False + }, + "speculative_decoding": False, + "sharding": False + } +} From 5512c268d9c6d4bfd268be2ef0e1fb60573c40da Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Fri, 8 Nov 2024 16:29:58 -0800 Subject: [PATCH 05/25] Add ModelBuilder support for JumpStart-provided draft models. --- .../serve/builder/jumpstart_builder.py | 24 ++++-- src/sagemaker/serve/builder/model_builder.py | 31 +++----- src/sagemaker/serve/utils/optimize_utils.py | 76 +++++++++++++++++-- 3 files changed, 100 insertions(+), 31 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 8593f1913b..036ea6e011 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -49,6 +49,7 @@ SPECULATIVE_DRAFT_MODEL, _is_inferentia_or_trainium, _validate_and_set_eula_for_draft_model_sources, + _jumpstart_speculative_decoding, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -503,7 +504,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): ) def set_deployment_config( - self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False + self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False ) -> None: """Sets the deployment config to apply to the model. 
@@ -735,6 +736,10 @@ def _optimize_for_jumpstart( optimization_config, quantization_override_env, compilation_override_env = ( _extract_optimization_config_and_env(quantization_config, compilation_config) ) + + if not optimization_config: + optimization_config = {} + if ( not optimization_config or not optimization_config.get("ModelCompilationConfig") ) and is_compilation: @@ -844,6 +849,7 @@ def _set_additional_model_source( """ if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) if model_provider == "sagemaker": @@ -868,17 +874,23 @@ def _set_additional_model_source( "Cannot find deployment config compatible for optimization job." ) - _validate_and_set_eula_for_draft_model_sources( - pysdk_model=self.pysdk_model, - accept_eula=speculative_decoding_config.get("AcceptEula"), - ) + _validate_and_set_eula_for_draft_model_sources( + pysdk_model=self.pysdk_model, + accept_eula=speculative_decoding_config.get("AcceptEula"), + ) self.pysdk_model.env.update( - {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}"} + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} ) self.pysdk_model.add_tags( {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, ) + elif model_provider == "jumpstart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, + ) else: self.pysdk_model = _custom_speculative_decoding( self.pysdk_model, speculative_decoding_config, accept_eula diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index bc8bef1626..a331b606e3 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -76,6 +76,7 @@ _is_s3_uri, 
_custom_speculative_decoding, _extract_speculative_draft_model_provider, + _jumpstart_speculative_decoding, ) from sagemaker.serve.utils.predictors import _get_local_mode_predictor from sagemaker.serve.utils.hardware_detector import ( @@ -99,7 +100,6 @@ validate_image_uri_and_hardware, ) from sagemaker.utils import Tags -from sagemaker.serve.utils.optimize_utils import _validate_and_set_eula_for_draft_model_sources from sagemaker.workflow.entities import PipelineVariable from sagemaker.huggingface.llm_utils import ( get_huggingface_model_metadata, @@ -590,21 +590,6 @@ def _model_builder_deploy_wrapper( model_server=self.model_server, ) - if self.deployment_config: - accept_draft_model_eula = kwargs.get("accept_draft_model_eula", False) - try: - _validate_and_set_eula_for_draft_model_sources( - pysdk_model=self, - accept_eula=accept_draft_model_eula, - ) - except ValueError as e: - logger.error( - "This deployment tried to use a gated draft model but the EULA was not " - "accepted. Please review the EULA, set accept_draft_model_eula to True, " - "and try again." - ) - raise e - if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True predictor = self._original_deploy( @@ -1358,9 +1343,17 @@ def _optimize_for_hf( Returns: Optional[Dict[str, Any]]: Model optimization job input arguments. 
""" - self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, False - ) + if speculative_decoding_config: + if speculative_decoding_config.get("ModelProvider", "") == "JumpStart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, + ) + else: + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) if quantization_config or compilation_config: create_optimization_job_args = { diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 53fe1a87ab..55443d174c 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -17,8 +17,10 @@ import logging from typing import Dict, Any, Optional, Union, List, Tuple -from sagemaker import Model +from sagemaker import Model, Session from sagemaker.enums import Tag +from sagemaker.jumpstart.utils import accessors, get_eula_message + logger = logging.getLogger(__name__) @@ -164,6 +166,9 @@ def _extract_speculative_draft_model_provider( if speculative_decoding_config is None: return None + if speculative_decoding_config.get("ModelProvider") == "JumpStart": + return "jumpstart" + if speculative_decoding_config.get( "ModelProvider" ) == "Custom" or speculative_decoding_config.get("ModelSource"): @@ -292,7 +297,7 @@ def _generate_additional_model_data_sources( }, } if accept_eula: - additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"ACCEPT_EULA": True} + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} return [additional_model_data_source] @@ -327,10 +332,10 @@ def _extract_optimization_config_and_env( """ optimization_config = {} quantization_override_env = ( - quantization_config.get("OverrideEnvironment", {}) if quantization_config else None + 
quantization_config.get("OverrideEnvironment") if quantization_config else None ) compilation_override_env = ( - compilation_config.get("OverrideEnvironment", {}) if compilation_config else None + compilation_config.get("OverrideEnvironment") if compilation_config else None ) if quantization_config is not None: @@ -343,7 +348,7 @@ def _extract_optimization_config_and_env( if optimization_config: return optimization_config, quantization_override_env, compilation_override_env - return {}, None, None + return None, None, None def _custom_speculative_decoding( @@ -364,7 +369,7 @@ def _custom_speculative_decoding( speculative_decoding_config ) - accept_eula = speculative_decoding_config.get("AcceptEula", False) + accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula) if _is_s3_uri(additional_model_source): channel_name = _generate_channel_name(model.additional_model_data_sources) @@ -384,6 +389,65 @@ def _custom_speculative_decoding( return model +def _jumpstart_speculative_decoding( + model=Model, + speculative_decoding_config: Optional[Dict[str, Any]] = None, + sagemaker_session: Optional[Session] = None, +): + """Modifies the given model for speculative decoding config with JumpStart provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + sagemaker_session (Optional[Session]): Sagemaker session for execution. + """ + if speculative_decoding_config: + js_id = speculative_decoding_config.get("ModelID") + if not js_id: + raise ValueError( + "`ModelID` is a required field in `speculative_decoding_config` when " + "using JumpStart as draft model provider." 
+ ) + model_version = speculative_decoding_config.get("ModelVersion", "*") + accept_eula = speculative_decoding_config.get("AcceptEula", False) + channel_name = _generate_channel_name(model.additional_model_data_sources) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_id=js_id, + version=model_version, + region=sagemaker_session.boto_region_name, + sagemaker_session=sagemaker_session, + ) + model_spec_json = model_specs.to_json() + + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket() + + if model_spec_json.get("gated_bucket", False): + if not accept_eula: + eula_message = get_eula_message( + model_specs=model_specs, region=sagemaker_session.boto_region_name + ) + raise ValueError( + f"{eula_message} Please set `AcceptEula` to True in " + f"speculative_decoding_config once acknowledged." + ) + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() + + key_prefix = model_spec_json.get("hosting_prepacked_artifact_key") + model.additional_model_data_sources = _generate_additional_model_data_sources( + f"s3://{js_bucket}/{key_prefix}", + channel_name, + accept_eula, + ) + + model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} + ) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, + ) + + def _validate_and_set_eula_for_draft_model_sources( pysdk_model: Model, accept_eula: bool = False, From c94a78b14f5cd3b540275639ebab14e9534c0a9e Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Fri, 8 Nov 2024 17:04:02 -0800 Subject: [PATCH 06/25] Tweak draft model EULA validations and messaging. Remove redundant deployment_config flow validation in optimize_utils in favor of the one directly on jumpstart/factory/model. 
--- .../serve/builder/jumpstart_builder.py | 28 +++---- src/sagemaker/serve/utils/optimize_utils.py | 77 +------------------ 2 files changed, 15 insertions(+), 90 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 036ea6e011..7b6001b912 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -48,7 +48,6 @@ _custom_speculative_decoding, SPECULATIVE_DRAFT_MODEL, _is_inferentia_or_trainium, - _validate_and_set_eula_for_draft_model_sources, _jumpstart_speculative_decoding, ) from sagemaker.serve.utils.predictors import ( @@ -837,9 +836,7 @@ def _is_gated_model(self, model=None) -> bool: return "private" in s3_uri def _set_additional_model_source( - self, - speculative_decoding_config: Optional[Dict[str, Any]] = None, - accept_eula: Optional[bool] = None, + self, speculative_decoding_config: Optional[Dict[str, Any]] = None ) -> None: """Set Additional Model Source to ``this`` model. 
@@ -849,6 +846,7 @@ def _set_additional_model_source( """ if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + accept_draft_model_eula = speculative_decoding_config.get("AcceptEula", False) channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) @@ -865,20 +863,22 @@ def _set_additional_model_source( speculative_decoding_config ) if deployment_config: - self.pysdk_model.set_deployment_config( - config_name=deployment_config.get("DeploymentConfigName"), - instance_type=deployment_config.get("InstanceType"), - ) + try: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + accept_draft_model_eula=accept_draft_model_eula, + ) + except ValueError as e: + raise ValueError( + f"{e} If using speculative_decoding_config, " + "accept the EULA by setting `AcceptEula`=True." + ) else: raise ValueError( "Cannot find deployment config compatible for optimization job." 
) - _validate_and_set_eula_for_draft_model_sources( - pysdk_model=self.pysdk_model, - accept_eula=speculative_decoding_config.get("AcceptEula"), - ) - self.pysdk_model.env.update( {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} ) @@ -893,7 +893,7 @@ def _set_additional_model_source( ) else: self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, accept_eula + self.pysdk_model, speculative_decoding_config, accept_draft_model_eula ) def _find_compatible_deployment_config( diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 55443d174c..05f467e856 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -428,7 +428,7 @@ def _jumpstart_speculative_decoding( model_specs=model_specs, region=sagemaker_session.boto_region_name ) raise ValueError( - f"{eula_message} Please set `AcceptEula` to True in " + f"{eula_message} Set `AcceptEula`=True in " f"speculative_decoding_config once acknowledged." ) js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() @@ -446,78 +446,3 @@ def _jumpstart_speculative_decoding( model.add_tags( {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, ) - - -def _validate_and_set_eula_for_draft_model_sources( - pysdk_model: Model, - accept_eula: bool = False, -): - """Validates whether the EULA has been accepted for gated additional draft model sources. - - If accepted, updates the model data source's model access config. - - Args: - pysdk_model (Model): The model whose additional model data sources to check. - accept_eula (bool): EULA acceptance for the draft model. 
- """ - if not pysdk_model: - return - - deployment_config_draft_model_sources = ( - pysdk_model.deployment_config.get("DeploymentArgs", {}) - .get("AdditionalDataSources", {}) - .get("speculative_decoding", []) - if pysdk_model.deployment_config - else None - ) - pysdk_model_additional_model_sources = pysdk_model.additional_model_data_sources - - if not deployment_config_draft_model_sources or not pysdk_model_additional_model_sources: - return - - # Gated/ungated classification is only available through deployment_config. - # Thus we must check each draft model in the deployment_config and see if it is set - # as an additional model data source on the PySDK model itself. - model_access_config_updated = False - for source in deployment_config_draft_model_sources: - if source.get("channel_name") != "draft_model": - continue - - if not _is_draft_model_gated(source): - continue - - deployment_config_draft_model_source_s3_uri = ( - _extract_deployment_config_additional_model_data_source_s3_uri(source) - ) - - # If EULA is accepted, proceed with modifying the draft model data source - for additional_source in pysdk_model_additional_model_sources: - if additional_source.get("ChannelName") != "draft_model": - continue - - # Verify the pysdk model source and deployment config model source match - pysdk_model_source_s3_uri = _extract_additional_model_data_source_s3_uri( - additional_source - ) - if deployment_config_draft_model_source_s3_uri not in pysdk_model_source_s3_uri: - continue - - if not accept_eula: - raise ValueError( - "Gated draft model requires accepting end-user license agreement (EULA)." 
- ) - - # Set ModelAccessConfig.AcceptEula to True - updated_source = additional_source.copy() - updated_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} - - index = pysdk_model.additional_model_data_sources.index(additional_source) - pysdk_model.additional_model_data_sources[index] = updated_source - - model_access_config_updated = True - break - - if model_access_config_updated: - break - - return From d10c475ecab60fe1b8c16aabec6441c0b5a9bfea Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Mon, 11 Nov 2024 15:48:56 -0800 Subject: [PATCH 07/25] Add "Auto" speculative decoding ModelProvider option; add validations to differentiate SageMaker/JumpStart draft models. --- src/sagemaker/jumpstart/factory/model.py | 13 ++--- .../serve/builder/jumpstart_builder.py | 32 ++++++++++-- src/sagemaker/serve/utils/optimize_utils.py | 50 +++++++++++++++++-- 3 files changed, 81 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 99cdaf6cab..e357fd22b0 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -564,10 +564,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta def _apply_accept_eula_on_model_data_source( - model_data_source: Dict[str, Any], - model_id: str, - region: str, - accept_eula: bool + model_data_source: Dict[str, Any], model_id: str, region: str, accept_eula: bool ): """Sets AcceptEula to True for gated speculative decoding models""" @@ -586,7 +583,8 @@ def _apply_accept_eula_on_model_data_source( f"'{model_id}' that requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." f"{get_domain_for_region(region)}" - f"/{hosting_eula_key} for terms of use. Please set `accept_eula=True` once acknowledged." + f"/{hosting_eula_key} for terms of use. Please set `accept_draft_model_eula=True` " + f"once acknowledged." 
) ) @@ -608,7 +606,10 @@ def _add_additional_model_data_sources_to_kwargs( [ camel_case_to_pascal_case( _apply_accept_eula_on_model_data_source( - data_source.to_json(), kwargs.model_id, kwargs.region, kwargs.accept_draft_model_eula, + data_source.to_json(), + kwargs.model_id, + kwargs.region, + kwargs.accept_draft_model_eula, ) ) for data_source in speculative_decoding_data_sources diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 7b6001b912..ca9812c1d8 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -49,6 +49,8 @@ SPECULATIVE_DRAFT_MODEL, _is_inferentia_or_trainium, _jumpstart_speculative_decoding, + _deployment_config_contains_draft_model, + _is_draft_model_jumpstart_provided, ) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, @@ -850,7 +852,7 @@ def _set_additional_model_source( channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) - if model_provider == "sagemaker": + if model_provider in ["sagemaker", "auto"]: additional_model_data_sources = ( self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( "AdditionalDataSources" @@ -863,6 +865,15 @@ def _set_additional_model_source( speculative_decoding_config ) if deployment_config: + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." + ) + try: self.pysdk_model.set_deployment_config( config_name=deployment_config.get("DeploymentConfigName"), @@ -878,12 +889,21 @@ def _set_additional_model_source( raise ValueError( "Cannot find deployment config compatible for optimization job." 
) + else: + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + self.pysdk_model.deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." + ) self.pysdk_model.env.update( {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} ) self.pysdk_model.add_tags( - {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider}, ) elif model_provider == "jumpstart": _jumpstart_speculative_decoding( @@ -911,15 +931,17 @@ def _find_compatible_deployment_config( for deployment_config in self.pysdk_model.list_deployment_configs(): image_uri = deployment_config.get("deployment_config", {}).get("ImageUri") - if _is_image_compatible_with_optimization_job(image_uri): + if _is_image_compatible_with_optimization_job( + image_uri + ) and _deployment_config_contains_draft_model(deployment_config): if ( - model_provider == "sagemaker" + model_provider in ["sagemaker", "auto"] and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") ) or model_provider == "custom": return deployment_config # There's no matching config from jumpstart to add sagemaker draft model location - if model_provider == "sagemaker": + if model_provider in ["sagemaker", "auto"]: return None # fall back to the default jumpstart model deployment config for optimization job diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 05f467e856..1859e6b589 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -60,6 +60,47 @@ def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) +def _deployment_config_contains_draft_model(deployment_config: 
Optional[Dict]) -> bool: + """Checks whether a deployment config contains a speculative decoding draft model. + + Args: + deployment_config (Dict): The deployment config to check. + + Returns: + bool: Whether the deployment config contains a draft model or not. + """ + if deployment_config is None: + return False + deployment_args = deployment_config.get("DeploymentArgs", {}) + additional_data_sources = deployment_args.get("AdditionalDataSources") + if not additional_data_sources: + return False + return additional_data_sources.get("speculative_decoding", False) + + +def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool: + """Checks whether a deployment config's draft model is provided by JumpStart. + + Args: + deployment_config (Dict): The deployment config to check. + + Returns: + bool: Whether the draft model is provided by JumpStart or not. + """ + if deployment_config is None: + return False + + additional_model_data_sources = deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + for source in additional_model_data_sources.get("speculative_decoding", []): + if source["channel_name"] == "draft_model": + if source.get("provider", {}).get("name") == "JumpStart": + return True + continue + return False + + def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model: """Generates a new optimization model. 
@@ -166,15 +207,18 @@ def _extract_speculative_draft_model_provider( if speculative_decoding_config is None: return None - if speculative_decoding_config.get("ModelProvider") == "JumpStart": + if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart": return "jumpstart" if speculative_decoding_config.get( "ModelProvider" - ) == "Custom" or speculative_decoding_config.get("ModelSource"): + ).lower() == "custom" or speculative_decoding_config.get("ModelSource"): return "custom" - return "sagemaker" + if speculative_decoding_config.get("ModelProvider").lower() == "sagemaker": + return "sagemaker" + + return "auto" def _extract_additional_model_data_source_s3_uri( From 8fb27a0458ff1f023daebda12c6348ee76ee4b25 Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Mon, 11 Nov 2024 17:20:57 -0800 Subject: [PATCH 08/25] Fix JumpStartModel.AdditionalModelDataSource model access config assignment. --- src/sagemaker/jumpstart/factory/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index e357fd22b0..f97585749c 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -588,7 +588,9 @@ def _apply_accept_eula_on_model_data_source( ) ) - mutable_model_data_source["model_access_config"] = {"accept_eula": accept_eula} + mutable_model_data_source["s3_data_source"]["model_access_config"] = { + "accept_eula": accept_eula + } return mutable_model_data_source From 779f6d69ad34277fc3c0602266f63cddf0f94081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Tue, 12 Nov 2024 01:07:44 +0000 Subject: [PATCH 09/25] move the accept eula configurations into deploy flow --- src/sagemaker/jumpstart/factory/model.py | 46 ++--------------- src/sagemaker/jumpstart/model.py | 28 ++++++---- src/sagemaker/jumpstart/types.py | 9 ++-- src/sagemaker/jumpstart/utils.py | 51 ++++++++++++++++++- 
.../serve/builder/jumpstart_builder.py | 4 +- 5 files changed, 78 insertions(+), 60 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f97585749c..73d5ad3f07 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Union +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -58,7 +59,6 @@ update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, - get_jumpstart_content_bucket, ) from sagemaker.jumpstart.factory.utils import ( @@ -563,37 +563,6 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta return kwargs -def _apply_accept_eula_on_model_data_source( - model_data_source: Dict[str, Any], model_id: str, region: str, accept_eula: bool -): - """Sets AcceptEula to True for gated speculative decoding models""" - - mutable_model_data_source = model_data_source.copy() - - hosting_eula_key = mutable_model_data_source.get("hosting_eula_key") - del mutable_model_data_source["hosting_eula_key"] - - if not hosting_eula_key: - return mutable_model_data_source - - if not accept_eula: - raise ValueError( - ( - f"The set deployment config comes optimized with an additional model data source " - f"'{model_id}' that requires accepting end-user license agreement (EULA). " - f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." - f"{get_domain_for_region(region)}" - f"/{hosting_eula_key} for terms of use. Please set `accept_draft_model_eula=True` " - f"once acknowledged." 
- ) - ) - - mutable_model_data_source["s3_data_source"]["model_access_config"] = { - "accept_eula": accept_eula - } - return mutable_model_data_source - - def _add_additional_model_data_sources_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: @@ -606,14 +575,7 @@ def _add_additional_model_data_sources_to_kwargs( data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region)) api_shape_additional_model_data_sources = ( [ - camel_case_to_pascal_case( - _apply_accept_eula_on_model_data_source( - data_source.to_json(), - kwargs.model_id, - kwargs.region, - kwargs.accept_draft_model_eula, - ) - ) + camel_case_to_pascal_case(data_source.to_json()) for data_source in speculative_decoding_data_sources ] if specs.get_speculative_decoding_s3_data_sources() @@ -693,6 +655,7 @@ def get_deploy_kwargs( training_config_name: Optional[str] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[List[ModelAccessConfig]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -729,6 +692,7 @@ def get_deploy_kwargs( resources=resources, config_name=config_name, routing_config=routing_config, + model_access_configs=model_access_configs, ) deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) deploy_kwargs.specs = verify_model_region_and_return_specs( @@ -903,7 +867,6 @@ def get_init_kwargs( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, - accept_draft_model_eula: Optional[bool] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -938,7 +901,6 @@ def get_init_kwargs( resources=resources, config_name=config_name, additional_model_data_sources=additional_model_data_sources, - 
accept_draft_model_eula=accept_draft_model_eula, ) model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( kwargs=model_init_kwargs diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index a42333320c..05c6404501 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -18,6 +18,7 @@ import pandas as pd from botocore.exceptions import ClientError +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -51,6 +52,7 @@ add_instance_rate_stats_to_benchmark_metrics, deployment_config_response_data, _deployment_config_lru_cache, + _add_model_access_configs_to_model_data_sources, ) from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType @@ -111,7 +113,6 @@ def __init__( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, - accept_draft_model_eula: Optional[bool] = None, ): """Initializes a ``JumpStartModel``. @@ -302,10 +303,6 @@ def __init__( optionally applied to the model. additional_model_data_sources (Optional[Dict[str, Any]]): Additional location of SageMaker model data (default: None). - accept_draft_model_eula (bool): For draft models that require a Model Access Config, specify True or - False to indicate whether model terms of use have been accepted. - The `accept_draft_model_eula` value must be explicitly defined as `True` in order to - accept the end-user license agreement (EULA) that some Raises: ValueError: If the model ID is not recognized by JumpStart. 
""" @@ -365,7 +362,6 @@ def _validate_model_id_and_type(): resources=resources, config_name=config_name, additional_model_data_sources=additional_model_data_sources, - accept_draft_model_eula=accept_draft_model_eula ) self.orig_predictor_cls = predictor_cls @@ -463,7 +459,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: ) def set_deployment_config( - self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False + self, config_name: str, instance_type: str ) -> None: """Sets the deployment config to apply to the model. @@ -483,8 +479,7 @@ def set_deployment_config( instance_type=instance_type, config_name=config_name, sagemaker_session=self.sagemaker_session, - role=self.role, - accept_draft_model_eula=accept_draft_model_eula, + role=self.role ) @property @@ -674,6 +669,7 @@ def deploy( managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[List[ModelAccessConfig]] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -770,6 +766,11 @@ def deploy( (Default: EndpointType.MODEL_BASED). routing_config (Optional[Dict]): Settings the control how the endpoint routes incoming traffic to the instances that the endpoint hosts. + model_access_configs (Optional[List[ModelAccessConfig]]): For models that require Model Access Configs, + provide one or multiple ModelAccessConfig objects to indicate whether model terms of use have been accepted. + The `AcceptEula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some. + (Default: None) Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. 
@@ -810,6 +811,7 @@ def deploy( model_type=self.model_type, config_name=self.config_name, routing_config=routing_config, + model_access_configs=model_access_configs, ) if ( self.model_type == JumpStartModelType.PROPRIETARY @@ -819,6 +821,13 @@ def deploy( f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." ) + self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources( + self.additional_model_data_sources, + deploy_kwargs.model_access_configs, + deploy_kwargs.model_id, + deploy_kwargs.region, + ) + try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) except ClientError as e: @@ -1058,7 +1067,6 @@ def _get_deployment_configs( region=self.region, model_version=self.model_version, hub_arn=self.hub_arn, - accept_draft_model_eula=True, ) deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index f4fa587c9d..c8edd45447 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -17,6 +17,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import ( S3_PREFIX, @@ -2117,7 +2118,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "hub_content_type", "model_reference_arn", "specs", - "accept_draft_model_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -2133,7 +2133,6 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "training_instance_type", "config_name", "hub_content_type", - "accept_draft_model_eula", } def __init__( @@ -2168,7 +2167,6 @@ def __init__( resources: Optional[ResourceRequirements] = None, config_name: Optional[str] = None, additional_model_data_sources: Optional[Dict[str, Any]] = None, - accept_draft_model_eula: Optional[bool] = 
False ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -2202,7 +2200,6 @@ def __init__( self.resources = resources self.config_name = config_name self.additional_model_data_sources = additional_model_data_sources - self.accept_draft_model_eula = accept_draft_model_eula class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -2244,6 +2241,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "config_name", "routing_config", "specs", + "model_access_configs" ] SERIALIZATION_EXCLUSION_SET = { @@ -2257,6 +2255,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "config_name", + "model_access_configs" } def __init__( @@ -2295,6 +2294,7 @@ def __init__( endpoint_type: Optional[EndpointType] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[List[CoreModelAccessConfig]] = None ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -2332,6 +2332,7 @@ def __init__( self.endpoint_type = endpoint_type self.config_name = config_name self.routing_config = routing_config + self.model_access_configs = model_access_configs class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index b33d6563e5..1d94fdf06d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. 
"""This module contains utilities related to SageMaker JumpStart.""" from __future__ import absolute_import + from copy import copy import logging import os @@ -22,6 +23,7 @@ from botocore.exceptions import ClientError from packaging.version import Version import botocore +from sagemaker_core.shapes import ModelAccessConfig import sagemaker from sagemaker.config.config_schema import ( MODEL_ENABLE_NETWORK_ISOLATION_PATH, @@ -55,6 +57,7 @@ TagsDict, get_instance_rate_per_hour, get_domain_for_region, + camel_case_to_pascal_case, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.user_agent import get_user_agent_extra_suffix @@ -555,11 +558,17 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: """Returns EULA message to display if one is available, else empty string.""" if model_specs.hosting_eula_key is None: return "" + return format_eula_message_from_specs( + model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key + ) + + +def format_eula_message_from_specs(model_id: str, region: str, hosting_eula_key: str): return ( - f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " + f"Model '{model_id}' requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." f"{get_domain_for_region(region)}" - f"/{model_specs.hosting_eula_key} for terms of use." + f"/{hosting_eula_key} for terms of use." 
) @@ -1525,3 +1534,41 @@ def wrapped_f(*args, **kwargs): if _func is None: return wrapper_cache return wrapper_cache(_func) + + +def _add_model_access_configs_to_model_data_sources( + model_data_sources: List[Dict[str, any]], + model_access_configs: List[ModelAccessConfig], + model_id: str, + region: str, +): + """Sets AcceptEula to True for gated speculative decoding models""" + + if not model_data_sources: + return model_data_sources + + acked_model_data_sources = [] + acked_model_access_configs = 0 + for model_data_source in model_data_sources: + hosting_eula_key = model_data_source.pop("HostingEulaKey", None) + if hosting_eula_key: + if not model_access_configs or acked_model_access_configs == len(model_access_configs): + eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}" + raise ValueError(eula_message_template.format( + model_source="Draft " if model_data_source.get("ChannelName") else "", + base_eula_message=format_eula_message_from_specs( + model_id=model_id, region=region, hosting_eula_key=hosting_eula_key + ), + model_access_configs_message=( + " Please add a ModelAccessConfig with AcceptEula=True" + " to model_access_configs to acknowledge the EULA." 
+ ) + )) + acked_model_data_source = model_data_source.copy() + acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( + camel_case_to_pascal_case(model_access_configs[acked_model_access_configs].model_dump()) + ) + acked_model_data_sources.append(acked_model_data_source) + else: + acked_model_data_sources.append(model_data_source) + return acked_model_data_sources diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index ca9812c1d8..fbebf13591 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -505,7 +505,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): ) def set_deployment_config( - self, config_name: str, instance_type: str, accept_draft_model_eula: Optional[bool] = False + self, config_name: str, instance_type: str ) -> None: """Sets the deployment config to apply to the model. @@ -522,7 +522,7 @@ def set_deployment_config( if not hasattr(self, "pysdk_model") or self.pysdk_model is None: raise Exception("Cannot set deployment config to an uninitialized model.") - self.pysdk_model.set_deployment_config(config_name, instance_type, accept_draft_model_eula) + self.pysdk_model.set_deployment_config(config_name, instance_type) self.deployment_config_name = config_name self.instance_type = instance_type From b7b15b803d9ac01bbc779d475b79b336b9d688e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Tue, 12 Nov 2024 01:07:44 +0000 Subject: [PATCH 10/25] move the accept eula configurations into deploy flow --- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/model.py | 6 +++--- src/sagemaker/jumpstart/types.py | 2 +- src/sagemaker/jumpstart/utils.py | 18 +++++++++++------- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 73d5ad3f07..9513b3702a 100644 --- 
a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -655,7 +655,7 @@ def get_deploy_kwargs( training_config_name: Optional[str] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, - model_access_configs: Optional[List[ModelAccessConfig]] = None, + model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 05c6404501..8275ac4549 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -766,9 +766,9 @@ def deploy( (Default: EndpointType.MODEL_BASED). routing_config (Optional[Dict]): Settings the control how the endpoint routes incoming traffic to the instances that the endpoint hosts. - model_access_configs (Optional[List[ModelAccessConfig]]): For models that require Model Access Configs, - provide one or multiple ModelAccessConfig objects to indicate whether model terms of use have been accepted. - The `AcceptEula` value must be explicitly defined as `True` in order to + model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require ModelAccessConfig, + provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` to indicate whether model terms + of use have been accepted. The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some. 
(Default: None) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c8edd45447..3eea06c5dc 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2294,7 +2294,7 @@ def __init__( endpoint_type: Optional[EndpointType] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, - model_access_configs: Optional[List[CoreModelAccessConfig]] = None + model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1d94fdf06d..d44d81be0c 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1538,7 +1538,7 @@ def wrapped_f(*args, **kwargs): def _add_model_access_configs_to_model_data_sources( model_data_sources: List[Dict[str, any]], - model_access_configs: List[ModelAccessConfig], + model_access_configs: Dict[str, ModelAccessConfig], model_id: str, region: str, ): @@ -1548,25 +1548,29 @@ def _add_model_access_configs_to_model_data_sources( return model_data_sources acked_model_data_sources = [] - acked_model_access_configs = 0 for model_data_source in model_data_sources: - hosting_eula_key = model_data_source.pop("HostingEulaKey", None) + hosting_eula_key = model_data_source.get("HostingEulaKey") if hosting_eula_key: - if not model_access_configs or acked_model_access_configs == len(model_access_configs): + if not model_access_configs or not model_access_configs.get(model_id): eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}" + model_access_config_entry = ( + "\"{model_id}\":ModelAccessConfig(accept_eula=True)".format(model_id=model_id) + ) raise ValueError(eula_message_template.format( model_source="Draft " if model_data_source.get("ChannelName") else "", base_eula_message=format_eula_message_from_specs( model_id=model_id, region=region, 
hosting_eula_key=hosting_eula_key ), model_access_configs_message=( - " Please add a ModelAccessConfig with AcceptEula=True" - " to model_access_configs to acknowledge the EULA." + " Please add a ModelAccessConfig entry:" + f" {model_access_config_entry} " + "to model_access_configs to acknowledge the EULA." ) )) acked_model_data_source = model_data_source.copy() + acked_model_data_source.pop("HostingEulaKey") acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( - camel_case_to_pascal_case(model_access_configs[acked_model_access_configs].model_dump()) + camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump()) ) acked_model_data_sources.append(acked_model_data_source) else: From 748ea4b690491246d43c2f0779dc746eed90b9c5 Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Tue, 12 Nov 2024 16:44:13 -0800 Subject: [PATCH 11/25] Use correct bucket for SM/JS draft models and minor formatting/validation updates. --- src/sagemaker/jumpstart/factory/model.py | 7 +-- src/sagemaker/jumpstart/model.py | 16 +++---- src/sagemaker/jumpstart/types.py | 13 +++--- src/sagemaker/jumpstart/utils.py | 43 +++++++++++++------ .../serve/builder/jumpstart_builder.py | 4 +- src/sagemaker/serve/builder/model_builder.py | 2 +- 6 files changed, 52 insertions(+), 33 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 9513b3702a..82bc1fc174 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -54,11 +54,11 @@ add_hub_content_arn_tags, add_jumpstart_model_info_tags, get_default_jumpstart_session_with_user_agent_suffix, - get_neo_content_bucket, get_top_ranked_config_name, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + get_draft_model_content_bucket, ) from sagemaker.jumpstart.factory.utils import ( @@ -76,7 +76,6 @@ name_from_base, format_tags, Tags, - get_domain_for_region, ) from 
sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -572,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs( # Append speculative decoding data source from metadata speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() for data_source in speculative_decoding_data_sources: - data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region)) + data_source.s3_data_source.set_bucket( + get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region) + ) api_shape_additional_model_data_sources = ( [ camel_case_to_pascal_case(data_source.to_json()) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 8275ac4549..69be68fbfe 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -458,9 +458,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) - def set_deployment_config( - self, config_name: str, instance_type: str - ) -> None: + def set_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. Args: @@ -479,7 +477,7 @@ def set_deployment_config( instance_type=instance_type, config_name=config_name, sagemaker_session=self.sagemaker_session, - role=self.role + role=self.role, ) @property @@ -766,11 +764,11 @@ def deploy( (Default: EndpointType.MODEL_BASED). routing_config (Optional[Dict]): Settings the control how the endpoint routes incoming traffic to the instances that the endpoint hosts. - model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require ModelAccessConfig, - provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` to indicate whether model terms - of use have been accepted. 
The `accept_eula` value must be explicitly defined as `True` in order to - accept the end-user license agreement (EULA) that some. - (Default: None) + model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require + ModelAccessConfig, provide a `{ "model_id": ModelAccessConfig(accept_eula=True) }` + to indicate whether model terms of use have been accepted. The `accept_eula` value + must be explicitly defined as `True` in order to accept the end-user license + agreement (EULA) that some models require. (Default: None) Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 3eea06c5dc..6be49c31ef 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1082,7 +1082,7 @@ def set_bucket(self, bucket: str) -> None: class AdditionalModelDataSource(JumpStartDataHolderType): """Data class of additional model data source mirrors CreateModel API.""" - SERIALIZATION_EXCLUSION_SET: Set[str] = set() + SERIALIZATION_EXCLUSION_SET = {"provider"} __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"] @@ -1103,6 +1103,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.channel_name: str = json_obj["channel_name"] self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) self.hosting_eula_key: str = json_obj.get("hosting_eula_key") + self.provider: Dict = json_obj.get("provider", {}) def to_json(self, exclude_keys=True) -> Dict[str, Any]: """Returns json representation of AdditionalModelDataSource object.""" @@ -1121,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]: class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" - SERIALIZATION_EXCLUSION_SET = {"artifact_version"} + SERIALIZATION_EXCLUSION_SET = { + "artifact_version" + } | AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET __slots__ = 
list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ @@ -2241,7 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "config_name", "routing_config", "specs", - "model_access_configs" + "model_access_configs", ] SERIALIZATION_EXCLUSION_SET = { @@ -2255,7 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "config_name", - "model_access_configs" + "model_access_configs", } def __init__( @@ -2294,7 +2297,7 @@ def __init__( endpoint_type: Optional[EndpointType] = None, config_name: Optional[str] = None, routing_config: Optional[Dict[str, Any]] = None, - model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None + model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d44d81be0c..f7339d058a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -564,6 +564,7 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: def format_eula_message_from_specs(model_id: str, region: str, hosting_eula_key: str): + """Returns a formatted EULA message.""" return ( f"Model '{model_id}' requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." 
@@ -1552,21 +1553,25 @@ def _add_model_access_configs_to_model_data_sources( hosting_eula_key = model_data_source.get("HostingEulaKey") if hosting_eula_key: if not model_access_configs or not model_access_configs.get(model_id): - eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}" + eula_message_template = ( + "{model_source}{base_eula_message}{model_access_configs_message}" + ) model_access_config_entry = ( - "\"{model_id}\":ModelAccessConfig(accept_eula=True)".format(model_id=model_id) + '"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id) ) - raise ValueError(eula_message_template.format( - model_source="Draft " if model_data_source.get("ChannelName") else "", - base_eula_message=format_eula_message_from_specs( - model_id=model_id, region=region, hosting_eula_key=hosting_eula_key - ), - model_access_configs_message=( - " Please add a ModelAccessConfig entry:" - f" {model_access_config_entry} " - "to model_access_configs to acknowledge the EULA." + raise ValueError( + eula_message_template.format( + model_source="Draft " if model_data_source.get("ChannelName") else "", + base_eula_message=format_eula_message_from_specs( + model_id=model_id, region=region, hosting_eula_key=hosting_eula_key + ), + model_access_configs_message=( + " Please add a ModelAccessConfig entry:" + f" {model_access_config_entry} " + "to model_access_configs to acknowledge the EULA." 
+ ), ) - )) + ) acked_model_data_source = model_data_source.copy() acked_model_data_source.pop("HostingEulaKey") acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( @@ -1576,3 +1581,17 @@ def _add_model_access_configs_to_model_data_sources( else: acked_model_data_sources.append(model_data_source) return acked_model_data_sources + + +def get_draft_model_content_bucket(provider: Dict, region: str) -> str: + """Returns the correct content bucket for a 1p draft model.""" + neo_bucket = get_neo_content_bucket(region=region) + if not provider: + return neo_bucket + provider_name = provider.get("name", "") + if provider_name == "JumpStart": + classification = provider.get("classification", "ungated") + if classification == "gated": + return get_jumpstart_gated_content_bucket(region=region) + return get_jumpstart_content_bucket(region=region) + return neo_bucket diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index fbebf13591..6ef7992382 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -504,9 +504,7 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) - def set_deployment_config( - self, config_name: str, instance_type: str - ) -> None: + def set_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. Args: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index a331b606e3..e2cbd5c795 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1344,7 +1344,7 @@ def _optimize_for_hf( Optional[Dict[str, Any]]: Model optimization job input arguments. 
""" if speculative_decoding_config: - if speculative_decoding_config.get("ModelProvider", "") == "JumpStart": + if speculative_decoding_config.get("ModelProvider", "").lower() == "jumpstart": _jumpstart_speculative_decoding( model=self.pysdk_model, speculative_decoding_config=speculative_decoding_config, From a7feb549a74f06ed266170bdc8bc01feee1289c0 Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Tue, 12 Nov 2024 16:50:36 -0800 Subject: [PATCH 12/25] Remove obsolete docstring. --- src/sagemaker/serve/builder/jumpstart_builder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 6ef7992382..e1c9d0eb8e 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -514,8 +514,6 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: instance_type (str): The instance_type that the model will use after setting the config. - accept_draft_model_eula (Optional[bool]): - If the config selected comes with a gated additional model data source. 
""" if not hasattr(self, "pysdk_model") or self.pysdk_model is None: raise Exception("Cannot set deployment config to an uninitialized model.") From 694b4f2f3886a74a4ebb25b89bf9717ecc7ff36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Wed, 13 Nov 2024 18:13:11 +0000 Subject: [PATCH 13/25] remove references to accept_draft_model_eula --- src/sagemaker/jumpstart/model.py | 2 -- src/sagemaker/serve/builder/jumpstart_builder.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 69be68fbfe..1cfaa0a709 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -468,8 +468,6 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: instance_type (str): The instance_type that the model will use after setting the config. - accept_draft_model_eula (Optional[bool]): - If the config selected comes with a gated additional model data source. 
""" self.__init__( model_id=self.model_id, diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e1c9d0eb8e..4b68fa20f3 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -844,7 +844,6 @@ def _set_additional_model_source( """ if speculative_decoding_config: model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) - accept_draft_model_eula = speculative_decoding_config.get("AcceptEula", False) channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) @@ -874,7 +873,6 @@ def _set_additional_model_source( self.pysdk_model.set_deployment_config( config_name=deployment_config.get("DeploymentConfigName"), instance_type=deployment_config.get("InstanceType"), - accept_draft_model_eula=accept_draft_model_eula, ) except ValueError as e: raise ValueError( @@ -909,7 +907,7 @@ def _set_additional_model_source( ) else: self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, accept_draft_model_eula + self.pysdk_model, speculative_decoding_config, speculative_decoding_config.get("AcceptEula", False) ) def _find_compatible_deployment_config( From 7b6aef1a4000efd3059d5b4cd7b3982ffdacff3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Wed, 13 Nov 2024 18:31:08 +0000 Subject: [PATCH 14/25] renaming of eula fn and error msg --- src/sagemaker/jumpstart/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f7339d058a..3e574cc383 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -558,12 +558,12 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: """Returns EULA message to display if one is available, else empty string.""" if model_specs.hosting_eula_key is None: 
return "" - return format_eula_message_from_specs( + return format_eula_message_template( model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key ) -def format_eula_message_from_specs(model_id: str, region: str, hosting_eula_key: str): +def format_eula_message_template(model_id: str, region: str, hosting_eula_key: str): """Returns a formatted EULA message.""" return ( f"Model '{model_id}' requires accepting end-user license agreement (EULA). " @@ -1561,8 +1561,8 @@ def _add_model_access_configs_to_model_data_sources( ) raise ValueError( eula_message_template.format( - model_source="Draft " if model_data_source.get("ChannelName") else "", - base_eula_message=format_eula_message_from_specs( + model_source="Additional " if model_data_source.get("ChannelName") else "", + base_eula_message=format_eula_message_template( model_id=model_id, region=region, hosting_eula_key=hosting_eula_key ), model_access_configs_message=( From 1f75072145bfcfe84aeff552abd3505992ef331e Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:06:09 -0800 Subject: [PATCH 15/25] fix: pin testing deps (#4925) Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com> --- requirements/extras/test_requirements.txt | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index a2c0fbfc65..1592576a47 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -22,26 +22,30 @@ requests==2.32.2 sagemaker-experiments==0.1.35 Jinja2==3.1.4 pyvis==0.2.1 -pandas>=1.3.5,<1.5 +pandas==1.4.4 scikit-learn==1.3.0 cloudpickle>=2.2.1 +jsonpickle<4.0.0 PyYAML==6.0 # TODO find workaround xgboost>=1.6.2,<=1.7.6 pillow>=10.0.1,<=11 -transformers>=4.36.0 +opentelemetry-proto==1.27.0 +protobuf==4.25.5 +tensorboard>=2.9.0,<=2.15.2 
+transformers==4.46.1 sentencepiece==0.1.99 # https://github.com/triton-inference-server/server/issues/6246 tritonclient[http]<2.37.0 -onnx>=1.15.0 +onnx==1.17.0 # tf2onnx==1.15.1 nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 -tensorflow>=2.1,<=2.16 +tensorflow>=2.9.0,<=2.15.1 mlflow>=2.12.2,<2.13 -huggingface_hub>=0.23.4 +huggingface_hub==0.26.2 uvicorn>=0.30.1 -fastapi>=0.111.0 +fastapi==0.115.4 nest-asyncio sagemaker-mlflow>=0.1.0 From 277e0b1c729a7edcc10990e86d6c4aaff94b5563 Mon Sep 17 00:00:00 2001 From: Haotian An <33510317+Captainia@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:14:52 -0500 Subject: [PATCH 16/25] Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926) --- CONTRIBUTING.md | 2 +- .../image_uri_config/huggingface-llm.json | 47 ------------------- .../image_uris/test_huggingface_llm.py | 1 - 3 files changed, 1 insertion(+), 49 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b2bcf44cd1..24226af4ee 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -76,7 +76,7 @@ Before sending us a pull request, please ensure that: 1. Install tox using `pip install tox` 1. Install coverage using `pip install .[test]` 1. cd into the sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk` -1. Run the following tox command and verify that all code checks and unit tests pass: `tox -- tests/unit` +1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit` 1. You can also run a single test with the following command: `tox -e py310 -- -s -vv ::` 1. 
You can run coverage via runcvoerage env : `tox -e runcoverage -- tests/unit` or `tox -e py310 -- tests/unit --cov=sagemaker --cov-append --cov-report xml` * Note that the coverage test will fail if you only run a single test, so make sure to surround the command with `export IGNORE_COVERAGE=-` and `unset IGNORE_COVERAGE` diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 42f160eff1..24cbd5ca96 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -766,53 +766,6 @@ "container_version": { "gpu": "cu124-ubuntu22.04" } - }, - "2.4.0": { - "py_versions": [ - "py311" - ], - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ap-southeast-4": "457447274322", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", - "me-central-1": "914824155844", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "tag_prefix": "2.4.0-tgi2.4.0", - "repository": "huggingface-pytorch-tgi-inference", - 
"container_version": { - "gpu": "cu124-ubuntu22.04" - } } } } diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index d993979cfd..28525a390c 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -46,7 +46,6 @@ "2.0.2": "2.3.0-tgi2.0.2-gpu-py310-cu121-ubuntu22.04", "2.2.0": "2.3.0-tgi2.2.0-gpu-py310-cu121-ubuntu22.04-v2.0", "2.3.1": "2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04", - "2.4.0": "2.4.0-tgi2.4.0-gpu-py311-cu124-ubuntu22.04", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", From 8f0083b65c52337a431881a772cad8099b1db1e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Thu, 14 Nov 2024 01:47:23 +0000 Subject: [PATCH 17/25] fix naming and messaging --- src/sagemaker/jumpstart/model.py | 4 ++-- src/sagemaker/jumpstart/utils.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1cfaa0a709..65bb156ee3 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -665,7 +665,7 @@ def deploy( managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, routing_config: Optional[Dict[str, Any]] = None, - model_access_configs: Optional[List[ModelAccessConfig]] = None, + model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -766,7 +766,7 @@ def deploy( ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` to indicate whether model terms of use have been accepted. The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license - agreement (EULA) that some. (Default: None) + agreement (EULA) that some models require. 
(Default: None) Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3e574cc383..a270b11915 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -558,12 +558,12 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: """Returns EULA message to display if one is available, else empty string.""" if model_specs.hosting_eula_key is None: return "" - return format_eula_message_template( + return get_formatted_eula_message_template( model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key ) -def format_eula_message_template(model_id: str, region: str, hosting_eula_key: str): +def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str: """Returns a formatted EULA message.""" return ( f"Model '{model_id}' requires accepting end-user license agreement (EULA). " @@ -1562,13 +1562,13 @@ def _add_model_access_configs_to_model_data_sources( raise ValueError( eula_message_template.format( model_source="Additional " if model_data_source.get("ChannelName") else "", - base_eula_message=format_eula_message_template( + base_eula_message=get_formatted_eula_message_template( model_id=model_id, region=region, hosting_eula_key=hosting_eula_key ), model_access_configs_message=( - " Please add a ModelAccessConfig entry:" + "Please add a ModelAccessConfig entry:" f" {model_access_config_entry} " - "to model_access_configs to acknowledge the EULA." + "to model_access_configs to accept the EULA." ), ) ) From 8b73f3482076f3bb87dbec2b83b5db14fe6f85d4 Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Wed, 13 Nov 2024 18:15:18 -0800 Subject: [PATCH 18/25] ModelBuilder speculative decoding UTs and minor fixes. 
--- src/sagemaker/jumpstart/types.py | 4 +- .../serve/builder/jumpstart_builder.py | 8 +- src/sagemaker/serve/builder/model_builder.py | 2 +- src/sagemaker/serve/utils/optimize_utils.py | 23 ++- .../serve/builder/test_js_builder.py | 37 ++++- tests/unit/sagemaker/serve/constants.py | 48 ++++++ .../serve/utils/test_optimize_utils.py | 156 +++++++++++++++++- 7 files changed, 245 insertions(+), 33 deletions(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 6be49c31ef..fab71679e8 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1122,9 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]: class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" - SERIALIZATION_EXCLUSION_SET = { + SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union( "artifact_version" - } | AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET + ) __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 4b68fa20f3..7d6a052023 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -737,9 +737,7 @@ def _optimize_for_jumpstart( if not optimization_config: optimization_config = {} - if ( - not optimization_config or not optimization_config.get("ModelCompilationConfig") - ) and is_compilation: + if not optimization_config.get("ModelCompilationConfig") and is_compilation: # Fallback to default if override_env is None or empty if not compilation_override_env: compilation_override_env = pysdk_model_env_vars @@ -907,7 +905,9 @@ def _set_additional_model_source( ) else: self.pysdk_model = _custom_speculative_decoding( - self.pysdk_model, speculative_decoding_config, speculative_decoding_config.get("AcceptEula", False) + 
self.pysdk_model, + speculative_decoding_config, + speculative_decoding_config.get("AcceptEula", False), ) def _find_compatible_deployment_config( diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index e2cbd5c795..501574c60e 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -591,7 +591,7 @@ def _model_builder_deploy_wrapper( ) if "endpoint_logging" not in kwargs: - kwargs["endpoint_logging"] = True + kwargs["endpoint_logging"] = False predictor = self._original_deploy( *args, instance_type=instance_type, diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 1859e6b589..14df6b3639 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -73,9 +73,8 @@ def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) - return False deployment_args = deployment_config.get("DeploymentArgs", {}) additional_data_sources = deployment_args.get("AdditionalDataSources") - if not additional_data_sources: - return False - return additional_data_sources.get("speculative_decoding", False) + + return "speculative_decoding" in additional_data_sources if additional_data_sources else False def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool: @@ -207,15 +206,15 @@ def _extract_speculative_draft_model_provider( if speculative_decoding_config is None: return None - if speculative_decoding_config.get("ModelProvider").lower() == "jumpstart": + model_provider = speculative_decoding_config.get("ModelProvider", "").lower() + + if model_provider == "jumpstart": return "jumpstart" - if speculative_decoding_config.get( - "ModelProvider" - ).lower() == "custom" or speculative_decoding_config.get("ModelSource"): + if model_provider == "custom" or speculative_decoding_config.get("ModelSource"): return "custom" - if 
speculative_decoding_config.get("ModelProvider").lower() == "sagemaker": + if model_provider == "sagemaker": return "sagemaker" return "auto" @@ -238,7 +237,7 @@ def _extract_additional_model_data_source_s3_uri( ): return None - return additional_model_data_source.get("S3DataSource").get("S3Uri", None) + return additional_model_data_source.get("S3DataSource").get("S3Uri") def _extract_deployment_config_additional_model_data_source_s3_uri( @@ -272,7 +271,7 @@ def _is_draft_model_gated( Returns: bool: Whether the draft model is gated or not. """ - return draft_model_config.get("hosting_eula_key", None) + return "hosting_eula_key" in draft_model_config if draft_model_config else False def _extracts_and_validates_speculative_model_source( @@ -371,7 +370,7 @@ def _extract_optimization_config_and_env( compilation_config (Optional[Dict]): The compilation config. Returns: - Optional[Tuple[Optional[Dict], Optional[Dict]]]: + Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: The optimization config and environment variables. 
""" optimization_config = {} @@ -388,7 +387,7 @@ def _extract_optimization_config_and_env( if compilation_config is not None: optimization_config["ModelCompilationConfig"] = compilation_config - # Return both dicts and environment variable if either is present + # Return optimization config dict and environment variables if either is present if optimization_config: return optimization_config, quantization_override_env, compilation_override_env diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 265907db45..25bc67d22d 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -28,6 +28,7 @@ from tests.unit.sagemaker.serve.constants import ( DEPLOYMENT_CONFIGS, OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES, ) mock_model_id = "huggingface-llm-amazon-falconlite" @@ -1203,19 +1204,34 @@ def test_optimize_quantize_for_jumpstart( @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) - def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false( + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._jumpstart_speculative_decoding", + return_value=True, + ) + def test_jumpstart_model_provider_calls_jumpstart_speculative_decoding( self, + mock_js_speculative_decoding, + mock_pretrained_js_model, + mock_is_js_model, mock_serve_settings, - mock_telemetry, + mock_capture_telemetry, ): mock_sagemaker_session = Mock() - mock_pysdk_model = Mock() mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} mock_pysdk_model.model_data = mock_model_data 
mock_pysdk_model.image_uri = mock_tgi_image_uri mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL + mock_pysdk_model.additional_model_data_sources = CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES sample_input = { "inputs": "The diamondback terrapin or simply terrapin is a species " @@ -1238,14 +1254,17 @@ def test_optimize_gated_draft_model_for_jumpstart_with_accept_eula_false( model_builder.pysdk_model = mock_pysdk_model - self.assertRaises( - ValueError, - model_builder._optimize_for_jumpstart( - accept_eula=True, - speculative_decoding_config={"Provider": "sagemaker", "AcceptEula": False}, - ), + model_builder._optimize_for_jumpstart( + accept_eula=True, + speculative_decoding_config={ + "ModelProvider": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + }, ) + mock_js_speculative_decoding.assert_called_once() + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) def test_optimize_quantize_and_compile_for_jumpstart( diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index c473750411..3e776eaa46 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -165,6 +165,43 @@ }, }, ] +NON_OPTIMIZED_DEPLOYMENT_CONFIG = { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": 
"s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, +} OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = { "DeploymentConfigName": "lmi-optimized", "DeploymentArgs": { @@ -267,3 +304,14 @@ "sagemaker-speculative-decoding-llama3-small-v3/", }, } +CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "CompressionType": "None", + "S3DataType": "S3Prefix", + }, + } +] diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 8b11b40060..7cf0406f42 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -32,10 +32,14 @@ _custom_speculative_decoding, _is_inferentia_or_trainium, _is_draft_model_gated, + _deployment_config_contains_draft_model, + _jumpstart_speculative_decoding, ) from tests.unit.sagemaker.serve.constants import ( GATED_DRAFT_MODEL_CONFIG, NON_GATED_DRAFT_MODEL_CONFIG, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + NON_OPTIMIZED_DEPLOYMENT_CONFIG, ) mock_optimization_job_output = { @@ -185,6 +189,9 @@ def test_update_environment_variables(env, new_env, output_env): ({"ModelProvider": "SageMaker"}, "sagemaker"), 
({"ModelProvider": "Custom"}, "custom"), ({"ModelSource": "s3://"}, "custom"), + ({"ModelProvider": "JumpStart"}, "jumpstart"), + ({"ModelProvider": "asdf"}, "auto"), + ({"ModelProvider": "Auto"}, "auto"), (None, None), ], ) @@ -229,7 +236,7 @@ def test_generate_additional_model_data_sources(): "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "S3DataType": "S3Prefix", "CompressionType": "None", - "ModelAccessConfig": {"ACCEPT_EULA": True}, + "ModelAccessConfig": {"AcceptEula": True}, }, } ] @@ -268,12 +275,12 @@ def test_is_s3_uri(s3_uri, expected): @pytest.mark.parametrize( "draft_model_config, expected", [ - (GATED_DRAFT_MODEL_CONFIG, NON_GATED_DRAFT_MODEL_CONFIG), - (True, False), + (GATED_DRAFT_MODEL_CONFIG, True), + (NON_GATED_DRAFT_MODEL_CONFIG, False), ], ) def test_is_draft_model_gated(draft_model_config, expected): - assert _is_draft_model_gated(draft_model_config, expected) + assert _is_draft_model_gated(draft_model_config) is expected @pytest.mark.parametrize( @@ -334,6 +341,145 @@ def test_extract_optimization_config_and_env( ) +@pytest.mark.parametrize( + "deployment_config", + [ + (OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, True), + (NON_OPTIMIZED_DEPLOYMENT_CONFIG, False), + (None, False), + ], +) +def deployment_config_contains_draft_model(deployment_config, expected): + assert _deployment_config_contains_draft_model(deployment_config) + + +class TestJumpStartSpeculativeDecodingConfig(unittest.TestCase): + + @patch("sagemaker.model.Model") + def test_with_no_js_model_id(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = {"ModelSource": "JumpStart"} + + with self.assertRaises(ValueError) as _: + _jumpstart_speculative_decoding(mock_model, speculative_decoding_config) + + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket", + return_value="js_gated_content_bucket", + ) + @patch( + 
"sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket", + return_value="js_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs", + return_value=Mock(), + ) + @patch("sagemaker.model.Model") + def test_with_gated_js_model( + self, + mock_model, + mock_model_specs, + mock_js_content_bucket, + mock_js_gated_content_bucket, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.boto_region_name = "us-west-2" + + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": True, + } + + mock_model_specs.return_value.to_json.return_value = { + "gated_bucket": True, + "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key", + } + + _jumpstart_speculative_decoding( + mock_model, speculative_decoding_config, mock_sagemaker_session + ) + + expected_env_var = { + "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/" + } + self.maxDiff = None + + self.assertEqual( + mock_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": f"s3://{mock_js_gated_content_bucket.return_value}/hosting_prepacked_artifact_key", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ], + ) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"} + ) + self.assertEqual(mock_model.env, expected_env_var) + + @patch( + "sagemaker.serve.utils.optimize_utils.get_eula_message", return_value="Accept eula message" + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket", + return_value="js_gated_content_bucket", + ) + @patch( + 
"sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket", + return_value="js_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs", + return_value=Mock(), + ) + @patch("sagemaker.model.Model") + def test_with_gated_js_model_and_accept_eula_false( + self, + mock_model, + mock_model_specs, + mock_js_content_bucket, + mock_js_gated_content_bucket, + mock_eula_message, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.boto_region_name = "us-west-2" + + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + } + + mock_model_specs.return_value.to_json.return_value = { + "gated_bucket": True, + "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key", + } + + self.assertRaisesRegex( + ValueError, + f"{mock_eula_message.return_value} Set `AcceptEula`=True in " + f"speculative_decoding_config once acknowledged.", + _jumpstart_speculative_decoding, + mock_model, + speculative_decoding_config, + mock_sagemaker_session, + ) + + class TestCustomSpeculativeDecodingConfig(unittest.TestCase): @patch("sagemaker.model.Model") @@ -387,7 +533,7 @@ def test_with_s3_js(self, mock_model): "S3Uri": "s3://bucket/huggingface-pytorch-tgi-inference", "S3DataType": "S3Prefix", "CompressionType": "None", - "ModelAccessConfig": {"ACCEPT_EULA": True}, + "ModelAccessConfig": {"AcceptEula": True}, }, } ], From 09a54dc1cc101d033bc4f223979c93d67525e98c Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Thu, 14 Nov 2024 11:14:45 -0800 Subject: [PATCH 19/25] Fix set union. 
--- src/sagemaker/jumpstart/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index fab71679e8..cb989ca4d4 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1123,7 +1123,7 @@ class JumpStartModelDataSource(AdditionalModelDataSource): """Data class JumpStart additional model data source.""" SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union( - "artifact_version" + {"artifact_version"} ) __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ From 3b147cd5b916717656aea6dc34a46b860d83166a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Fri, 15 Nov 2024 00:49:22 +0000 Subject: [PATCH 20/25] add UTs for JumpStart deployment --- src/sagemaker/jumpstart/model.py | 2 + src/sagemaker/jumpstart/utils.py | 18 +- .../sagemaker/jumpstart/model/test_model.py | 92 +++++++- tests/unit/sagemaker/jumpstart/test_utils.py | 217 ++++++++++++++++++ tests/unit/sagemaker/jumpstart/utils.py | 24 ++ 5 files changed, 345 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 65bb156ee3..0f916e5ff6 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -817,12 +817,14 @@ def deploy( f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." 
) + print(self.additional_model_data_sources) self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources( self.additional_model_data_sources, deploy_kwargs.model_access_configs, deploy_kwargs.model_id, deploy_kwargs.region, ) + print(self.additional_model_data_sources) try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a270b11915..0a9fd9967d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1544,6 +1544,7 @@ def _add_model_access_configs_to_model_data_sources( region: str, ): """Sets AcceptEula to True for gated speculative decoding models""" + print(model_data_sources) if not model_data_sources: return model_data_sources @@ -1551,8 +1552,13 @@ def _add_model_access_configs_to_model_data_sources( acked_model_data_sources = [] for model_data_source in model_data_sources: hosting_eula_key = model_data_source.get("HostingEulaKey") + mutable_model_data_source = model_data_source.copy() if hosting_eula_key: - if not model_access_configs or not model_access_configs.get(model_id): + if ( + not model_access_configs + or not model_access_configs.get(model_id) + or not model_access_configs.get(model_id).accept_eula + ): eula_message_template = ( "{model_source}{base_eula_message}{model_access_configs_message}" ) @@ -1572,14 +1578,14 @@ def _add_model_access_configs_to_model_data_sources( ), ) ) - acked_model_data_source = model_data_source.copy() - acked_model_data_source.pop("HostingEulaKey") - acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( + mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is applied + mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump()) ) - acked_model_data_sources.append(acked_model_data_source) + 
acked_model_data_sources.append(mutable_model_data_source) else: - acked_model_data_sources.append(model_data_source) + mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is not applicable + acked_model_data_sources.append(mutable_model_data_source) return acked_model_data_sources diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 4c573bca8c..4defe33929 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -19,6 +19,7 @@ import pandas as pd from mock import MagicMock, Mock import pytest +from sagemaker_core.shapes import ModelAccessConfig from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, @@ -54,6 +55,7 @@ get_base_deployment_configs, get_base_spec_with_prototype_configs_with_missing_benchmarks, append_instance_stat_metrics, + append_gated_draft_model_specs_to_jumpstart_model_spec, ) import boto3 @@ -772,6 +774,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): init_args_to_skip: Set[str] = set(["model_reference_arn"]) deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"]) + deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) @@ -798,8 +801,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): js_class_deploy = JumpStartModel.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == set() - assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set() + assert (parent_class_deploy_args - 
js_class_deploy_args - deploy_args_removed_at_deploy_time == + deploy_args_to_skip) @mock.patch( "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} @@ -1762,6 +1766,89 @@ def test_model_set_deployment_config( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + # WHERE + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: + get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-eqa-bert-base-cased" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN + model.deploy(model_access_configs={"pytorch-eqa-bert-base-cased":ModelAccessConfig(accept_eula=True)}) + + # THEN + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, 
**kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + # WHERE + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: + get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-eqa-bert-base-cased" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN / THEN + with self.assertRaises(ValueError): + model.deploy() + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch( "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" @@ -1810,6 +1897,7 @@ def test_model_deployment_config_additional_model_data_source( "S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/", "ModelAccessConfig": {"AcceptEula": False}, }, + "HostingEulaKey": None, } ], ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe2ba749cd..8f85cf8514 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -20,6 +20,7 @@ import pytest import boto3 import random +from sagemaker_core.shapes import ModelAccessConfig from 
sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( @@ -2168,3 +2169,219 @@ def test_get_domain_for_region(self): " https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key " "for terms of use.", ) + + +class TestAcceptEulaModelAccessConfig(TestCase): + MOCK_PUBLIC_MODEL_ID = "mock_public_model_id" + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + 'ChannelName': 'draft_model', + 'S3DataSource': { + 'CompressionType': 'None', + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/' + }, + 'HostingEulaKey': None + } + ] + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + 'ChannelName': 'draft_model', + 'S3DataSource': { + 'CompressionType': 'None', + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/' + } + } + ] + MOCK_GATED_MODEL_ID = "mock_gated_model_id" + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + 'ChannelName': 'draft_model', + 'S3DataSource': { + 'CompressionType': 'None', + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/' + }, + 'HostingEulaKey': "fmhMetadata/eula/llama3_2Eula.txt" + } + ] + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + 'ChannelName': 'draft_model', + 'S3DataSource': { + 'CompressionType': 'None', + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/', + 'ModelAccessConfig': { + "AcceptEula": True + } + } + } + ] + + # Public Positive Cases + + def test_public_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + 
+ # THEN + assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + def test_multiple_public_additional_model_data_source_should_pass_through_both(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_public_additional_model_data_source_with_model_access_config_should_ignored_it(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + def test_no_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=None, + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert not additional_model_data_sources + + # Gated Positive Cases + + def test_gated_additional_model_data_source_should_accept_it(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + 
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + def test_multiple_gated_additional_model_data_source_should_accept_both(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs={ + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True), + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + # Mixed Positive Cases + + def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs={ + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + # Test Gated Negative 
Tests + + def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(self): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs=None, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error(self): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs=None, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(self): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_PUBLIC_MODEL_ID:ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(self): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=False) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py 
index de274f0374..74489824be 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -32,6 +32,7 @@ DeploymentConfigMetadata, JumpStartModelDeployKwargs, JumpStartBenchmarkStat, + JumpStartAdditionalDataSources, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -436,3 +437,26 @@ def append_instance_stat_metrics( ) ) return metrics + + +def append_gated_draft_model_specs_to_jumpstart_model_spec(*args, **kwargs): + augmented_spec = get_prototype_model_spec(*args, **kwargs) + augmented_spec.hosting_additional_data_sources = JumpStartAdditionalDataSources(spec={ + 'speculative_decoding': [ + { + 'channel_name': 'draft_model', + 'provider': { + 'name': 'JumpStart', + 'classification': 'gated' + }, + 'artifact_version': 'v1', + 'hosting_eula_key': 'fmhMetadata/eula/llama3_2Eula.txt', + 's3_data_source': { + 's3_uri': 'meta-textgeneration/meta-textgeneration-llama-3-2-1b-instruct/artifacts/inference-prepack/v1.0.0/', + 'compression_type': 'None', + 's3_data_type': 'S3Prefix' + } + } + ] + }) + return augmented_spec From 65cb5b3ee80ff39ed7461497a8d962a0c33158f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Fri, 15 Nov 2024 01:03:51 +0000 Subject: [PATCH 21/25] fix formatting issues --- src/sagemaker/jumpstart/utils.py | 8 +- .../check_optimization_configurations.py | 38 ------ .../sagemaker/jumpstart/model/test_model.py | 47 ++++--- tests/unit/sagemaker/jumpstart/test_utils.py | 123 ++++++++++-------- tests/unit/sagemaker/jumpstart/utils.py | 35 ++--- 5 files changed, 117 insertions(+), 134 deletions(-) delete mode 100644 src/sagemaker/serve/validations/check_optimization_configurations.py diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 0a9fd9967d..15a1373c50 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1578,13 +1578,17 @@ def 
_add_model_access_configs_to_model_data_sources( ), ) ) - mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is applied + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is applied mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump()) ) acked_model_data_sources.append(mutable_model_data_source) else: - mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is not applicable + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is not applicable acked_model_data_sources.append(mutable_model_data_source) return acked_model_data_sources diff --git a/src/sagemaker/serve/validations/check_optimization_configurations.py b/src/sagemaker/serve/validations/check_optimization_configurations.py deleted file mode 100644 index 82477d02a1..0000000000 --- a/src/sagemaker/serve/validations/check_optimization_configurations.py +++ /dev/null @@ -1,38 +0,0 @@ -SUPPORTED_OPTIMIZATION_CONFIGURATIONS = { - "trt": { - "supported_instance_families": ["p4d", "p4de", "p5", "g5", "g6"], - "compilation": True, - "quantization": { - "awq": True, - "fp8": True, - "gptq": False, - "smooth_quant": True - }, - "speculative_decoding": False, - "sharding": False - }, - "vllm": { - "supported_instance_families": ["p4d", "p4de", "p5", "g5", "g6"], - "compilation": False, - "quantization": { - "awq": True, - "fp8": True, - "gptq": False, - "smooth_quant": False - }, - "speculative_decoding": True, - "sharding": True - }, - "neuron": { - "supported_instance_families": ["inf2", "trn1", "trn1n"], - "compilation": True, - "quantization": { - "awq": False, - "fp8": False, - "gptq": False, - "smooth_quant": False - }, - "speculative_decoding": False, - "sharding": False - } -} diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 
4defe33929..09e63f8f59 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -801,9 +801,14 @@ def test_jumpstart_model_kwargs_match_parent_class(self): js_class_deploy = JumpStartModel.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set() - assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time == - deploy_args_to_skip) + assert ( + js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time + == set() + ) + assert ( + parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time + == deploy_args_to_skip + ) @mock.patch( "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} @@ -1775,18 +1780,17 @@ def test_model_set_deployment_config( @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_set_deployment_config_and_deploy_for_gated_draft_model( - self, - mock_model_deploy: mock.Mock, - mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, - mock_get_manifest: mock.Mock, - mock_get_jumpstart_configs: mock.Mock, + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): # WHERE mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec mock_get_manifest.side_effect = ( - lambda region, model_type, *args, **kwargs: - get_prototype_manifest(region, model_type) + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) mock_model_deploy.return_value = default_predictor @@ -1799,7 +1803,11 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model( assert 
model.config_name is None # WHEN - model.deploy(model_access_configs={"pytorch-eqa-bert-base-cased":ModelAccessConfig(accept_eula=True)}) + model.deploy( + model_access_configs={ + "pytorch-eqa-bert-base-cased": ModelAccessConfig(accept_eula=True) + } + ) # THEN mock_model_deploy.assert_called_once_with( @@ -1822,18 +1830,17 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model( @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs( - self, - mock_model_deploy: mock.Mock, - mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, - mock_get_manifest: mock.Mock, - mock_get_jumpstart_configs: mock.Mock, + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): # WHERE mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec mock_get_manifest.side_effect = ( - lambda region, model_type, *args, **kwargs: - get_prototype_manifest(region, model_type) + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) mock_model_deploy.return_value = default_predictor diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 8f85cf8514..a056c315ab 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -2175,48 +2175,46 @@ class TestAcceptEulaModelAccessConfig(TestCase): MOCK_PUBLIC_MODEL_ID = "mock_public_model_id" MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ { - 'ChannelName': 'draft_model', - 'S3DataSource': { - 'CompressionType': 'None', - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/' + "ChannelName": "draft_model", + 
"S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", }, - 'HostingEulaKey': None + "HostingEulaKey": None, } ] MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ { - 'ChannelName': 'draft_model', - 'S3DataSource': { - 'CompressionType': 'None', - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/' - } + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", + }, } ] MOCK_GATED_MODEL_ID = "mock_gated_model_id" MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ { - 'ChannelName': 'draft_model', - 'S3DataSource': { - 'CompressionType': 'None', - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/' + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", }, - 'HostingEulaKey': "fmhMetadata/eula/llama3_2Eula.txt" + "HostingEulaKey": "fmhMetadata/eula/llama3_2Eula.txt", } ] MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ { - 'ChannelName': 'draft_model', - 'S3DataSource': { - 'CompressionType': 'None', - 'S3DataType': 'S3Prefix', - 'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/', - 'ModelAccessConfig': { - "AcceptEula": True - } - } + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + "ModelAccessConfig": {"AcceptEula": True}, + }, } ] @@ -2232,14 +2230,17 @@ def test_public_additional_model_data_source_should_pass_through(self): ) # THEN - assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + assert ( + additional_model_data_sources + == 
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) def test_multiple_public_additional_model_data_source_should_pass_through_both(self): # WHERE / WHEN additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( model_data_sources=( - self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ), model_access_configs=None, model_id=self.MOCK_PUBLIC_MODEL_ID, @@ -2248,23 +2249,24 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s # THEN assert additional_model_data_sources == ( - self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL ) def test_public_additional_model_data_source_with_model_access_config_should_ignored_it(self): # WHERE / WHEN additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, - model_access_configs={ - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) - }, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, ) # THEN - assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + assert ( + additional_model_data_sources + == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) def test_no_additional_model_data_source_should_pass_through(self): # WHERE / WHEN @@ -2284,26 +2286,27 @@ def test_gated_additional_model_data_source_should_accept_it(self): # WHERE / WHEN additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( 
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, - model_access_configs={ - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) - }, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, ) # THEN - assert additional_model_data_sources == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + assert ( + additional_model_data_sources + == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) def test_multiple_gated_additional_model_data_source_should_accept_both(self): # WHERE / WHEN additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( model_data_sources=( - self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ), model_access_configs={ - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True), - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True), + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True), }, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, @@ -2311,35 +2314,37 @@ def test_multiple_gated_additional_model_data_source_should_accept_both(self): # THEN assert additional_model_data_sources == ( - self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL ) # Mixed Positive Cases - def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(self): + def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other( + self, + ): # WHERE / WHEN 
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( model_data_sources=( - self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + - self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ), - model_access_configs={ - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True) - }, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, ) # THEN assert additional_model_data_sources == ( - self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + - self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL ) # Test Gated Negative Tests - def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(self): + def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error( + self, + ): # WHERE / WHEN / THEN with self.assertRaises(ValueError): utils._add_model_access_configs_to_model_data_sources( @@ -2354,33 +2359,37 @@ def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error with self.assertRaises(ValueError): utils._add_model_access_configs_to_model_data_sources( model_data_sources=( - self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + - self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL ), model_access_configs=None, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, ) - def 
test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(self): + def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error( + self, + ): # WHERE / WHEN / THEN with self.assertRaises(ValueError): utils._add_model_access_configs_to_model_data_sources( model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, model_access_configs={ - self.MOCK_PUBLIC_MODEL_ID:ModelAccessConfig(accept_eula=True) + self.MOCK_PUBLIC_MODEL_ID: ModelAccessConfig(accept_eula=True) }, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, ) - def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(self): + def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error( + self, + ): # WHERE / WHEN / THEN with self.assertRaises(ValueError): utils._add_model_access_configs_to_model_data_sources( model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, model_access_configs={ - self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=False) + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False) }, model_id=self.MOCK_GATED_MODEL_ID, region=JUMPSTART_DEFAULT_REGION_NAME, diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 74489824be..bd870dc461 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -441,22 +441,23 @@ def append_instance_stat_metrics( def append_gated_draft_model_specs_to_jumpstart_model_spec(*args, **kwargs): augmented_spec = get_prototype_model_spec(*args, **kwargs) - augmented_spec.hosting_additional_data_sources = JumpStartAdditionalDataSources(spec={ - 'speculative_decoding': [ - { - 'channel_name': 'draft_model', - 'provider': { - 'name': 'JumpStart', - 'classification': 'gated' - }, - 'artifact_version': 'v1', - 'hosting_eula_key': 
'fmhMetadata/eula/llama3_2Eula.txt', - 's3_data_source': { - 's3_uri': 'meta-textgeneration/meta-textgeneration-llama-3-2-1b-instruct/artifacts/inference-prepack/v1.0.0/', - 'compression_type': 'None', - 's3_data_type': 'S3Prefix' + + gated_s3_uri = "meta-textgeneration/meta-textgeneration-llama-3-2-1b-instruct/artifacts/inference-prepack/v1.0.0/" + augmented_spec.hosting_additional_data_sources = JumpStartAdditionalDataSources( + spec={ + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": gated_s3_uri, + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, } - } - ] - }) + ] + } + ) return augmented_spec From 4d1e12b496498510c075765bb253ec39db510288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Fri, 15 Nov 2024 18:18:38 +0000 Subject: [PATCH 22/25] address validation comments --- src/sagemaker/serve/builder/model_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 501574c60e..6c9a41b99b 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1283,7 +1283,7 @@ def _model_builder_optimize_wrapper( # TRTLLM is used by Neo if the following are provided: # 1) a GPU instance type # 2) compilation config - gpu_instance_families = ["g4", "g5", "p4d"] + gpu_instance_families = ["g5", "g6", "p4d", "p5"] is_gpu_instance = optimization_instance_type and any( gpu_instance_family in optimization_instance_type for gpu_instance_family in gpu_instance_families From bf706ad4c46c61a695a313e0f63a8ef7bd6d6ad1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Fri, 15 Nov 2024 18:49:36 +0000 Subject: [PATCH 23/25] fix doc strings --- 
src/sagemaker/jumpstart/utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 15a1373c50..f7f56fd59d 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1542,10 +1542,19 @@ def _add_model_access_configs_to_model_data_sources( model_access_configs: Dict[str, ModelAccessConfig], model_id: str, region: str, -): - """Sets AcceptEula to True for gated speculative decoding models""" - print(model_data_sources) +) -> List[Dict[str, any]]: + """Iterate over the accept EULA configs to ensure all channels are matched + Args: + model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated + model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field + model_id (DeploymentConfigMetadata): Jumpstart mode id. + region (str): Region where the user is operating in. + Returns: + List[Dict[str, Any]]: List of model data sources with accept EULA configs applied + Raise: + ValueError if at least one channel that requires EULA acceptance as not passed. + """ if not model_data_sources: return model_data_sources From f121eb06a80212294f62375eb93ff87c9bb3106d Mon Sep 17 00:00:00 2001 From: Joseph Zhang Date: Fri, 15 Nov 2024 10:57:52 -0800 Subject: [PATCH 24/25] Add TRTLLM compilation + speculative decoding validation. 
--- src/sagemaker/serve/builder/model_builder.py | 14 ++++- .../serve/builder/test_model_builder.py | 55 +++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 6c9a41b99b..61af6953a2 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1283,7 +1283,7 @@ def _model_builder_optimize_wrapper( # TRTLLM is used by Neo if the following are provided: # 1) a GPU instance type # 2) compilation config - gpu_instance_families = ["g5", "g6", "p4d", "p5"] + gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"] is_gpu_instance = optimization_instance_type and any( gpu_instance_family in optimization_instance_type for gpu_instance_family in gpu_instance_families @@ -1296,8 +1296,16 @@ def _model_builder_optimize_wrapper( keyword in self.model.lower() for keyword in llama_3_1_keywords ) - if is_gpu_instance and self.model and is_llama_3_1 and self.is_compiled: - raise ValueError("Compilation is not supported for Llama-3.1 with a GPU instance.") + if is_gpu_instance and self.model and self.is_compiled: + if is_llama_3_1: + raise ValueError( + "Compilation is not supported for Llama-3.1 with a GPU instance." + ) + if speculative_decoding_config: + raise ValueError( + "Compilation is not supported with speculative decoding with " + "a GPU instance." 
+ ) self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) job_status = self.sagemaker_session.wait_for_optimization_job(job_name) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index fffb548245..2da09aece3 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -2891,3 +2891,58 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( output_path="s3://bucket/code/", ), ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "modelid"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="modelid", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Compilation is not supported with speculative decoding with a GPU instance.", + lambda: model_builder.optimize( + job_name="job_name-123", + speculative_decoding_config={ + "ModelProvider": "custom", + "ModelSource": "s3://data-source", + }, + 
compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) From 9148e70e665327398babb0791d87b1c009542a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gary=20Wang=20=F0=9F=98=A4?= Date: Fri, 15 Nov 2024 23:36:59 +0000 Subject: [PATCH 25/25] address nits --- src/sagemaker/jumpstart/model.py | 2 -- src/sagemaker/jumpstart/utils.py | 2 +- tests/unit/sagemaker/jumpstart/test_utils.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 0f916e5ff6..65bb156ee3 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -817,14 +817,12 @@ def deploy( f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." ) - print(self.additional_model_data_sources) self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources( self.additional_model_data_sources, deploy_kwargs.model_access_configs, deploy_kwargs.model_id, deploy_kwargs.region, ) - print(self.additional_model_data_sources) try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f7f56fd59d..dfe3d7f1dd 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1548,7 +1548,7 @@ def _add_model_access_configs_to_model_data_sources( Args: model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field - model_id (DeploymentConfigMetadata): Jumpstart mode id. + model_id (DeploymentConfigMetadata): Jumpstart model id. region (str): Region where the user is operating in. 
Returns: List[Dict[str, Any]]: List of model data sources with accept EULA configs applied diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index a056c315ab..67681e2b7b 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -2253,7 +2253,7 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL ) - def test_public_additional_model_data_source_with_model_access_config_should_ignored_it(self): + def test_public_additional_model_data_source_with_model_access_config_should_ignore_it(self): # WHERE / WHEN additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,