diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py index a1f4567eb9..b64985cd65 100644 --- a/src/sagemaker/serve/builder/tei_builder.py +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -67,7 +67,7 @@ def __init__(self): self.role_arn = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, *args, **kwargs): """Placeholder docstring""" @abstractmethod @@ -164,15 +164,24 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa del kwargs["role"] if not _is_optimized(self.pysdk_model): - self._prepare_for_mode() + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: - self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"}) + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"}) else: # if has not been built for local container we must use cache # that hosting has write access to. 
- self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" @@ -191,6 +200,9 @@ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa predictor = self._original_deploy(*args, **kwargs) + if "HF_HUB_OFFLINE" in self.pysdk_model.env: + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"}) + predictor.serializer = serializer predictor.deserializer = deserializer return predictor diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index 558a560a74..9bde777af2 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -94,7 +94,7 @@ def __init__(self): self.role_arn = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, *args, **kwargs): """Placeholder docstring""" @abstractmethod @@ -203,15 +203,24 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa del kwargs["role"] if not _is_optimized(self.pysdk_model): - self._prepare_for_mode() + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: - self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"}) + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"}) else: # if has not been built for local container we must use cache # that hosting has write access to. 
- self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" @@ -242,7 +251,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa predictor = self._original_deploy(*args, **kwargs) - self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "0"}) + if "HF_HUB_OFFLINE" in self.pysdk_model.env: + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"}) predictor.serializer = serializer predictor.deserializer = deserializer diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index e064564961..570371e54d 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -84,7 +84,7 @@ def __init__(self): self.shared_libs = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, *args, **kwargs): """Abstract method""" def _create_transformers_model(self) -> Type[Model]: @@ -206,8 +206,6 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr else: raise ValueError("Mode %s is not supported!" 
% overwrite_mode) - self._set_instance() - serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer if self.mode == Mode.LOCAL_CONTAINER: @@ -227,6 +225,8 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr ) return predictor + self._set_instance(kwargs) + if "mode" in kwargs: del kwargs["mode"] if "role" in kwargs: @@ -234,7 +234,23 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr del kwargs["role"] if not _is_optimized(self.pysdk_model): - self._prepare_for_mode() + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) + + if ( + "SAGEMAKER_SERVE_SECRET_KEY" in self.pysdk_model.env + and not self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"] + ): + del self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"] if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True @@ -279,9 +295,11 @@ def _build_transformers_env(self): return self.pysdk_model - def _set_instance(self, **kwargs): + def _set_instance(self, kwargs): """Set the instance : Given the detected notebook type or provided instance type""" if self.mode == Mode.SAGEMAKER_ENDPOINT: + if "instance_type" in kwargs: + return if self.nb_instance_type and "instance_type" not in kwargs: kwargs.update({"instance_type": self.nb_instance_type}) logger.info("Setting instance type to %s", self.nb_instance_type) diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py index 54abbea0da..94265e224f 100644 --- a/src/sagemaker/serve/model_server/tei/server.py +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -17,7 
+17,6 @@ MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" _DEFAULT_ENV_VARS = { - "TRANSFORMERS_CACHE": "/opt/ml/model/", "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index 4d9686a89c..8ccc8e7ddc 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -17,7 +17,6 @@ MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" _DEFAULT_ENV_VARS = { - "TRANSFORMERS_CACHE": "/opt/ml/model/", "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index 417430e45d..3ecd55e301 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -44,6 +44,7 @@ "HF_MODEL_ID": "TheBloke/Llama-2-7b-chat-fp16", "TENSOR_PARALLEL_DEGREE": "1", "OPTION_DTYPE": "bf16", + "MODEL_LOADING_TIMEOUT": "1800", } mock_schema_builder = MagicMock() @@ -63,8 +64,13 @@ class TestDjlBuilder(unittest.TestCase): ) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") + @patch( + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), + ) def test_build_deploy_for_djl_local_container( self, + mock_default_djl_config, mock_get_nb_instance, mock_get_ram_usage_mb, mock_is_jumpstart_model, @@ -125,8 +131,13 @@ def test_build_deploy_for_djl_local_container( "sagemaker.serve.builder.djl_builder._concurrent_benchmark", side_effect=[(0.03, 16), (0.10, 4), (0.15, 2)], ) + @patch( + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), + ) def 
test_tune_for_djl_local_container( self, + mock_default_djl_config, mock_concurrent_benchmarks, mock_serial_benchmarks, mock_admissible_tensor_parallel_degrees, @@ -165,8 +176,10 @@ def test_tune_for_djl_local_container( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_deep_ping_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, @@ -204,8 +217,10 @@ def test_tune_for_djl_local_container_deep_ping_ex( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_load_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, @@ -245,8 +260,10 @@ def test_tune_for_djl_local_container_load_ex( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_oom_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, @@ -283,8 +300,10 @@ def test_tune_for_djl_local_container_oom_ex( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_invoke_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, diff --git a/tests/unit/sagemaker/serve/builder/test_tei_builder.py 
b/tests/unit/sagemaker/serve/builder/test_tei_builder.py index 4a75174bfc..2ede60290b 100644 --- a/tests/unit/sagemaker/serve/builder/test_tei_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_tei_builder.py @@ -20,10 +20,10 @@ from sagemaker.serve.utils.predictors import TeiLocalModePredictor -mock_model_id = "bert-base-uncased" -mock_prompt = "The man worked as a [MASK]." -mock_sample_input = {"inputs": mock_prompt} -mock_sample_output = [ +MOCK_MODEL_ID = "bert-base-uncased" +MOCK_PROMPT = "The man worked as a [MASK]." +MOCK_SAMPLE_INPUT = {"inputs": MOCK_PROMPT} +MOCK_SAMPLE_OUTPUT = [ { "score": 0.0974755585193634, "token": 10533, @@ -55,13 +55,14 @@ "sequence": "the man worked as a salesman.", }, ] -mock_schema_builder = MagicMock() -mock_schema_builder.sample_input = mock_sample_input -mock_schema_builder.sample_output = mock_sample_output +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT MOCK_IMAGE_CONFIG = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" ) +MOCK_MODEL_PATH = "mock model path" class TestTEIBuilder(unittest.TestCase): @@ -70,57 +71,136 @@ class TestTEIBuilder(unittest.TestCase): return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) - def test_build_deploy_for_tei_local_container_and_remote_container( + def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success( self, mock_get_nb_instance, mock_telemetry, ): + # verify SAGEMAKER_ENDPOINT deploy builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.SAGEMAKER_ENDPOINT, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + model 
= builder.build() + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_tei_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, vpc_config=MOCK_VPC_CONFIG, model_metadata={ "HF_TASK": "sentence-similarity", }, + model_path=MOCK_MODEL_PATH, ) builder._prepare_for_mode = MagicMock() builder._prepare_for_mode.side_effect = None - model = builder.build() builder.serve_settings.telemetry_opt_out = True - builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) assert model.vpc_config == MOCK_VPC_CONFIG assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" assert isinstance(predictor, TeiLocalModePredictor) - assert builder.nb_instance_type == "ml.g5.24xlarge" + # verify SAGEMAKER_ENDPOINT overwritten deploy builder._original_deploy = MagicMock() builder._prepare_for_mode.return_value = (None, {}) - predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") - assert "HF_MODEL_ID" in model.env + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True 
+ ) + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + @patch("sagemaker.serve.builder.tei_builder._is_optimized", return_value=True) + def test_tei_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() @patch( "sagemaker.serve.builder.tei_builder._get_nb_instance", return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) - def test_image_uri_override( + def test_tei_builder_image_uri_override_success( self, mock_get_nb_instance, mock_telemetry, ): builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, image_uri=MOCK_IMAGE_CONFIG, model_metadata={ diff --git a/tests/unit/sagemaker/serve/builder/test_tgi_builder.py b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py new file mode 100644 index 0000000000..c77dbfffd6 
--- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py @@ -0,0 +1,189 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import MagicMock, patch +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils.predictors import TgiLocalModePredictor + +MOCK_MODEL_ID = "meta-llama/Meta-Llama-3-8B" +MOCK_PROMPT = "The man worked as a [MASK]." 
+MOCK_SAMPLE_INPUT = {"inputs": "Hello, I'm a language model", "parameters": {"max_new_tokens": 128}} +MOCK_SAMPLE_OUTPUT = [{"generated_text": "Hello, I'm a language modeler."}] +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT +MOCK_IMAGE_CONFIG = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" +) +MOCK_MODEL_PATH = "mock model path" + + +class TestTGIBuilder(TestCase): + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify SAGEMAKER_ENDPOINT deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + 
"sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TgiLocalModePredictor) + assert builder.nb_instance_type == "ml.g5.24xlarge" + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True + ) + + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + 
@patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + @patch("sagemaker.serve.builder.tgi_builder._is_optimized", return_value=True) + def test_tgi_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py index 9ea797adc2..a5e269ea51 100644 --- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py @@ -20,10 +20,10 @@ from sagemaker.serve.utils.predictors import TransformersLocalModePredictor -mock_model_id = "bert-base-uncased" 
-mock_prompt = "The man worked as a [MASK]." -mock_sample_input = {"inputs": mock_prompt} -mock_sample_output = [ +MOCK_MODEL_ID = "bert-base-uncased" +MOCK_PROMPT = "The man worked as a [MASK]." +MOCK_SAMPLE_INPUT = {"inputs": MOCK_PROMPT} +MOCK_SAMPLE_OUTPUT = [ { "score": 0.0974755585193634, "token": 10533, @@ -55,13 +55,14 @@ "sequence": "the man worked as a salesman.", }, ] -mock_schema_builder = MagicMock() -mock_schema_builder.sample_input = mock_sample_input -mock_schema_builder.sample_output = mock_sample_output +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT MOCK_IMAGE_CONFIG = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" ) +MOCK_MODEL_PATH = "mock model path" class TestTransformersBuilder(unittest.TestCase): @@ -70,54 +71,124 @@ class TestTransformersBuilder(unittest.TestCase): return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) - def test_build_deploy_for_transformers_local_container_and_remote_container( + def test_transformers_builder_sagemaker_endpoint_mode_no_s3_upload_success( self, mock_get_nb_instance, mock_telemetry, ): + # verify SAGEMAKER_ENDPOINT deploy builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.SAGEMAKER_ENDPOINT + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + 
"sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + def test_transformers_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, vpc_config=MOCK_VPC_CONFIG, + model_path=MOCK_MODEL_PATH, ) builder._prepare_for_mode = MagicMock() builder._prepare_for_mode.side_effect = None - model = builder.build() builder.serve_settings.telemetry_opt_out = True - builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) assert model.vpc_config == MOCK_VPC_CONFIG assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" assert isinstance(predictor, TransformersLocalModePredictor) - assert builder.nb_instance_type == "ml.g5.24xlarge" + # verify SAGEMAKER_ENDPOINT overwritten deploy builder._original_deploy = MagicMock() builder._prepare_for_mode.return_value = (None, {}) - predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") - assert "HF_MODEL_ID" in model.env + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True + ) + + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + @patch("sagemaker.serve.builder.transformers_builder._is_optimized", return_value=True) + def 
test_transformers_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + builder._prepare_for_mode.assert_called_once_with() @patch( "sagemaker.serve.builder.transformers_builder._get_nb_instance", return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) - def test_image_uri_override( + def test_transformers_builder_image_uri_override_success( self, mock_get_nb_instance, mock_telemetry, ): builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, image_uri=MOCK_IMAGE_CONFIG, ) @@ -152,15 +223,15 @@ def test_image_uri_override( ) @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) @patch( - "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata", - return_value=None, + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={}, ) - def test_failure_hf_md( + def test_transformers_builder_empty_hf_md_defaults_to_transformers_success( self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers ): builder = 
ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, ) diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py index cc1226702f..47399c1fad 100644 --- a/tests/unit/sagemaker/serve/model_server/tei/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -65,7 +65,6 @@ def test_start_invoke_destroy_local_tei_server(self, mock_requests): auto_remove=True, volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}}, environment={ - "TRANSFORMERS_CACHE": "/opt/ml/model/", "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", "KEY": "VALUE", diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py new file mode 100644 index 0000000000..33371fc584 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py @@ -0,0 +1,283 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
from __future__ import absolute_import

from unittest import TestCase
from unittest.mock import Mock, patch
from pathlib import Path
from sagemaker.serve.model_server.tgi.server import LocalTgiServing, SageMakerTgiServing

# Inputs handed to the code under test.
MOCK_IMAGE = "mock image"
MOCK_MODEL_PATH = "mock model path"
MOCK_SECRET_KEY = "mock secret key"
MOCK_ENV_VARS = {"mock key": "mock value"}
MOCK_SAGEMAKER_SESSION = Mock()
MOCK_S3_MODEL_DATA_URL = "mock s3 path"
MOCK_MODEL_DATA_URL = "mock model data url"

# Values the assertions below expect to observe.
EXPECTED_MODE_DIR_BINDING = "/opt/ml/model/"
EXPECTED_SHM_SIZE = "2G"
EXPECTED_UPDATED_ENV_VARS = {
    "HF_HOME": "/opt/ml/model/",
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
    "mock key": "mock value",
}


def _s3_prefix_data_source(s3_uri):
    """Return the model-data dict the tests expect for ``s3_uri`` (a trailing "/" is added)."""
    return {
        "S3DataSource": {
            "CompressionType": "None",
            "S3DataType": "S3Prefix",
            "S3Uri": s3_uri + "/",
        }
    }


EXPECTED_MODEL_DATA = _s3_prefix_data_source(MOCK_MODEL_DATA_URL)


def _expected_run_kwargs(environment):
    """Return the keyword arguments ``containers.run`` should be called with.

    Only ``environment`` differs between the two local-serving tests; every
    other keyword is the same in both.
    """
    gpu_request = {
        "Driver": "",
        "Count": -1,
        "DeviceIDs": [],
        "Capabilities": [["gpu"]],
        "Options": {},
    }
    code_dir = Path(MOCK_MODEL_PATH).joinpath("code")
    return {
        "shm_size": EXPECTED_SHM_SIZE,
        "device_requests": [gpu_request],
        "network_mode": "host",
        "detach": True,
        "auto_remove": True,
        "volumes": {code_dir: {"bind": EXPECTED_MODE_DIR_BINDING, "mode": "rw"}},
        "environment": environment,
    }


class TestLocalTgiServing(TestCase):
    """Checks the ``containers.run`` call made by ``LocalTgiServing._start_tgi_serving``."""

    @staticmethod
    def _start(jumpstart):
        """Start serving against a mocked container client.

        Returns ``(serving, client, container)`` where ``container`` is the
        object the mocked ``containers.run`` hands back.
        """
        client = Mock()
        container = Mock()
        client.containers.run.return_value = container
        serving = LocalTgiServing()
        serving._start_tgi_serving(
            client,
            MOCK_IMAGE,
            MOCK_MODEL_PATH,
            MOCK_SECRET_KEY,
            MOCK_ENV_VARS,
            jumpstart,
        )
        return serving, client, container

    def test_tgi_serving_runs_container_non_jumpstart_success(self):
        # Act
        serving, client, container = self._start(False)

        # Assert: no command list is passed and the environment is the expected merged set.
        client.containers.run.assert_called_once_with(
            MOCK_IMAGE, **_expected_run_kwargs(EXPECTED_UPDATED_ENV_VARS)
        )
        assert serving.container == container

    def test_tgi_serving_runs_container_jumpstart_success(self):
        # Act
        serving, client, container = self._start(True)

        # Assert: a --model-id command is passed and the environment is forwarded as given.
        client.containers.run.assert_called_once_with(
            MOCK_IMAGE,
            ["--model-id", EXPECTED_MODE_DIR_BINDING],
            **_expected_run_kwargs(MOCK_ENV_VARS),
        )
        assert serving.container == container


class TestSageMakerTgiServing(TestCase):
    """Checks ``SageMakerTgiServing._upload_tgi_artifacts`` with its S3 helpers patched out."""

    @staticmethod
    def _prime_upload_mocks(
        s3_uploader, s3_path_join, determine_bucket_and_prefix, fw_utils, parse_s3_url, is_s3_uri
    ):
        """Configure the patched helpers for a model path that is not an S3 URI."""
        is_s3_uri.return_value = False
        parse_s3_url.return_value = ("mock_bucket_1", "mock_prefix_1")
        fw_utils.model_code_key_prefix.return_value = "mock_code_key_prefix"
        determine_bucket_and_prefix.return_value = ("mock_bucket_2", "mock_prefix_2")
        s3_path_join.return_value = "mock_s3_location"
        s3_uploader.upload.return_value = MOCK_MODEL_DATA_URL

    @staticmethod
    def _upload(jumpstart):
        """Call ``_upload_tgi_artifacts`` on a fresh instance with the module-level mock inputs.

        ``jumpstart`` is the third positional argument: ``True`` in the test
        named "jumpstart", ``False`` in the other two.
        """
        return SageMakerTgiServing()._upload_tgi_artifacts(
            MOCK_MODEL_PATH,
            MOCK_SAGEMAKER_SESSION,
            jumpstart,
            MOCK_S3_MODEL_DATA_URL,
            MOCK_IMAGE,
            MOCK_ENV_VARS,
            True,
        )

    @staticmethod
    def _assert_upload_calls(
        s3_uploader, s3_path_join, determine_bucket_and_prefix, fw_utils, parse_s3_url, is_s3_uri
    ):
        """Assert each patched helper was called exactly once with the expected arguments."""
        is_s3_uri.assert_called_once_with(MOCK_MODEL_PATH)
        parse_s3_url.assert_called_once_with(url=MOCK_S3_MODEL_DATA_URL)
        fw_utils.model_code_key_prefix.assert_called_once_with("mock_prefix_1", None, MOCK_IMAGE)
        determine_bucket_and_prefix.assert_called_once_with(
            bucket="mock_bucket_1",
            key_prefix="mock_code_key_prefix",
            sagemaker_session=MOCK_SAGEMAKER_SESSION,
        )
        s3_path_join.assert_called_once_with("s3://", "mock_bucket_2", "mock_prefix_2", "code")
        s3_uploader.upload.assert_called_once_with(
            f"{MOCK_MODEL_PATH}/code", "mock_s3_location", None, MOCK_SAGEMAKER_SESSION
        )

    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
    @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url")
    @patch("sagemaker.serve.model_server.tgi.server.fw_utils")
    @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix")
    @patch("sagemaker.serve.model_server.tgi.server.s3_path_join")
    @patch("sagemaker.serve.model_server.tgi.server.S3Uploader")
    def test_tgi_serving_upload_tgi_artifacts_s3_url_passed_success(
        self,
        mock_s3_uploader,
        mock_s3_path_join,
        mock_determine_bucket_and_prefix,
        mock_fw_utils,
        mock_parse_s3_url,
        mock_is_s3_uri,
    ):
        # Arrange: patch decorators inject mocks bottom-up, so S3Uploader comes first.
        patched = (
            mock_s3_uploader,
            mock_s3_path_join,
            mock_determine_bucket_and_prefix,
            mock_fw_utils,
            mock_parse_s3_url,
            mock_is_s3_uri,
        )
        self._prime_upload_mocks(*patched)

        # Act
        model_data, env_vars = self._upload(False)

        # Assert
        self._assert_upload_calls(*patched)
        assert model_data == EXPECTED_MODEL_DATA
        assert env_vars == EXPECTED_UPDATED_ENV_VARS

    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
    @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url")
    @patch("sagemaker.serve.model_server.tgi.server.fw_utils")
    @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix")
    @patch("sagemaker.serve.model_server.tgi.server.s3_path_join")
    @patch("sagemaker.serve.model_server.tgi.server.S3Uploader")
    def test_tgi_serving_upload_tgi_artifacts_jumpstart_success(
        self,
        mock_s3_uploader,
        mock_s3_path_join,
        mock_determine_bucket_and_prefix,
        mock_fw_utils,
        mock_parse_s3_url,
        mock_is_s3_uri,
    ):
        # Arrange: same mock ordering as the decorators (bottom-most patch first).
        patched = (
            mock_s3_uploader,
            mock_s3_path_join,
            mock_determine_bucket_and_prefix,
            mock_fw_utils,
            mock_parse_s3_url,
            mock_is_s3_uri,
        )
        self._prime_upload_mocks(*patched)

        # Act
        model_data, env_vars = self._upload(True)

        # Assert: upload happens as before, but no environment variables are returned.
        self._assert_upload_calls(*patched)
        assert model_data == EXPECTED_MODEL_DATA
        assert env_vars == {}

    @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri")
    @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url")
    @patch("sagemaker.serve.model_server.tgi.server.fw_utils")
    @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix")
    @patch("sagemaker.serve.model_server.tgi.server.s3_path_join")
    @patch("sagemaker.serve.model_server.tgi.server.S3Uploader")
    def test_tgi_serving_upload_tgi_artifacts(
        self,
        mock_s3_uploader,
        mock_s3_path_join,
        mock_determine_bucket_and_prefix,
        mock_fw_utils,
        mock_parse_s3_url,
        mock_is_s3_uri,
    ):
        # Arrange: the model path is reported as already being an S3 URI.
        mock_is_s3_uri.return_value = True

        # Act
        model_data, env_vars = self._upload(False)

        # Assert: none of the upload helpers are touched and the path is used as the S3 URI.
        mock_is_s3_uri.assert_called_once_with(MOCK_MODEL_PATH)
        untouched = (
            mock_parse_s3_url,
            mock_fw_utils.model_code_key_prefix,
            mock_determine_bucket_and_prefix,
            mock_s3_path_join,
            mock_s3_uploader.upload,
        )
        for helper in untouched:
            assert not helper.called
        assert model_data == _s3_prefix_data_source(MOCK_MODEL_PATH)
        assert env_vars == EXPECTED_UPDATED_ENV_VARS