diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index fb6f60b9d0..013f2bc79b 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -603,7 +603,6 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str): s3_upload_path, env_vars_sagemaker = self._prepare_for_mode() self.pysdk_model.model_data = s3_upload_path self.pysdk_model.env.update(env_vars_sagemaker) - elif overwrite_mode == Mode.LOCAL_CONTAINER: self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER self._prepare_for_mode() diff --git a/src/sagemaker/serve/builder/requirements_manager.py b/src/sagemaker/serve/builder/requirements_manager.py new file mode 100644 index 0000000000..7497bd471a --- /dev/null +++ b/src/sagemaker/serve/builder/requirements_manager.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
+from __future__ import absolute_import
+import logging
+import os
+import subprocess
+
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class RequirementsManager:
+    """Install client dependencies from a pip requirements file or a conda env file.
+
+    The dependencies file is dispatched on its extension: ``.txt`` is installed
+    with ``pip install -r`` and ``.yml`` is applied with ``conda env update -f``.
+    """
+
+    def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str:
+        """Install dependencies from the given file, or from an auto-detected one.
+
+        Args:
+            dependencies (str): Local path of a ``.txt`` requirements file or a
+                ``.yml`` conda environment file. When omitted, the path comes
+                from ``_detect_conda_env_and_local_dependencies``.
+
+        Returns:
+            str: Path of the dependencies file that was selected for installation.
+
+        Raises:
+            ValueError: If the path is not a string ending in ``.txt`` or ``.yml``.
+            subprocess.CalledProcessError: If the install command exits non-zero.
+        """
+        _dependencies = dependencies or self._detect_conda_env_and_local_dependencies()
+
+        # Reject non-string arguments with the same error as a bad extension
+        # instead of failing below with an AttributeError on ``endswith``.
+        if not isinstance(_dependencies, str):
+            raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')
+
+        # Only a caller-supplied path is forwarded. An auto-detected path is the
+        # default file name joined to the current working directory, which is
+        # the same file the helpers install when they are given no path.
+        explicit_path = dependencies or None
+
+        if _dependencies.endswith(".txt"):
+            self._install_requirements_txt(explicit_path)
+        elif _dependencies.endswith(".yml"):
+            self._update_conda_env_in_path(explicit_path)
+        else:
+            raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')
+
+        return _dependencies
+
+    def _install_requirements_txt(self, requirements_path: Optional[str] = None):
+        """Install a requirements file using pip.
+
+        Args:
+            requirements_path (str): Requirements file to install. When omitted,
+                ``in_process_requirements.txt`` is installed, resolved by pip
+                against the current working directory.
+
+        Raises:
+            subprocess.CalledProcessError: If pip exits with a non-zero status.
+        """
+        logger.info("Running command to pip install")
+        if requirements_path is None:
+            subprocess.run("pip install -r in_process_requirements.txt", shell=True, check=True)
+        else:
+            # An argument list keeps the path out of shell parsing, so spaces
+            # or shell metacharacters in it reach pip unchanged.
+            subprocess.run(["pip", "install", "-r", requirements_path], check=True)
+        logger.info("Command ran successfully")
+
+    def _update_conda_env_in_path(self, conda_env_path: Optional[str] = None):
+        """Update a conda env using a conda yml file.
+
+        Args:
+            conda_env_path (str): Conda environment file to apply. When omitted,
+                ``conda_in_process.yml`` is applied, resolved by conda against
+                the current working directory.
+
+        Raises:
+            subprocess.CalledProcessError: If conda exits with a non-zero status.
+        """
+        logger.info("Updating conda env")
+        if conda_env_path is None:
+            subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
+        else:
+            # Same reasoning as for pip: never let the shell interpret the path.
+            subprocess.run(["conda", "env", "update", "-f", conda_env_path], check=True)
+        logger.info("Conda env updated successfully")
+
+    def _get_active_conda_env_name(self) -> Optional[str]:
+        """Returns the value of CONDA_DEFAULT_ENV, or None when it is not set."""
+        return os.getenv("CONDA_DEFAULT_ENV")
+
+    def _get_active_conda_env_prefix(self) -> Optional[str]:
+        """Returns the value of CONDA_PREFIX, or None when it is not set."""
+        return os.getenv("CONDA_PREFIX")
+
+    def _detect_conda_env_and_local_dependencies(self) -> str:
+        """Pick the default dependencies file for the current runtime.
+
+        Returns:
+            str: ``conda_in_process.yml`` under the current working directory
+                when ``CONDA_DEFAULT_ENV`` or ``CONDA_PREFIX`` is set, otherwise
+                ``in_process_requirements.txt`` under the same directory.
+        """
+        conda_env_name = self._get_active_conda_env_name()
+        logger.info("Found conda_env_name: '%s'", conda_env_name)
+
+        # The prefix is only consulted when no environment name is set.
+        in_conda_env = (
+            conda_env_name is not None or self._get_active_conda_env_prefix() is not None
+        )
+
+        if not in_conda_env:
+            local_dependencies_path = os.path.join(os.getcwd(), "in_process_requirements.txt")
+            logger.info(local_dependencies_path)
+
+            return local_dependencies_path
+
+        if conda_env_name == "base":
+            logger.warning(
+                "We recommend using an environment other than base to "
+                "isolate your project dependencies from conda dependencies"
+            )
+
+        local_dependencies_path = os.path.join(os.getcwd(), "conda_in_process.yml")
+        logger.info(local_dependencies_path)
+
+        return local_dependencies_path
diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index 570371e54d..e5a616ea4b 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -17,6 +17,7 @@ from abc import ABC, abstractmethod from typing import Type from pathlib import Path +import subprocess from packaging.version import Version from sagemaker.model import Model @@ -41,6 +42,8 @@ from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.base_predictor import PredictorBase from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.serve.builder.requirements_manager import RequirementsManager + logger = logging.getLogger(__name__) DEFAULT_TIMEOUT = 1800 @@ -376,6 +379,9 @@ def _build_for_transformers(self):
save_pkl(code_path, (self.inference_spec, self.schema_builder)) logger.info("PKL file saved to file: %s", code_path) + if self.mode == Mode.IN_PROCESS: + self._create_conda_env() + self._auto_detect_container() self.secret_key = prepare_for_mms( @@ -394,3 +400,16 @@ def _build_for_transformers(self): if self.sagemaker_session: self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model
+
+    def _create_conda_env(self):
+        """Install the dependencies used for in-process serving.
+
+        The dependencies file is auto-detected by ``RequirementsManager``. A failing
+        install command is logged and otherwise ignored so that the build continues.
+        """
+        try:
+            # The builder itself is not a dependencies path, so nothing is passed
+            # and the manager falls back to its own detection.
+            RequirementsManager().capture_and_install_dependencies()
+        except subprocess.CalledProcessError:
+            logger.warning("Failed to create and activate conda environment.", exc_info=True)
diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index ccb73d8cb6..8586fa85fb 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -31,7 +31,7 @@ def _start_serving( secret_key: str, env_vars: dict, ): - """Placeholder docstring""" + """Initializes the start of the server""" env = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "SAGEMAKER_PROGRAM": "inference.py", @@ -59,7 +59,7 @@ def _start_serving( ) def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str): - """Placeholder docstring""" + """Invokes MMS server by hitting the docker host""" try: response = requests.post( f"http://{get_docker_host()}:8080/invocations", @@ -73,7 +73,7 @@ def _invoke_multi_model_server_serving(self, request: object, content_type: str, raise Exception("Unable to send request to the local container server") from e def _multi_model_server_deep_ping(self, predictor: PredictorBase): - """Placeholder docstring""" + """Deep ping in order to ensure prediction""" response = None try: response = predictor.predict(self.schema_builder.sample_input) diff --git a/src/sagemaker/serve/utils/conda_in_process.yml
b/src/sagemaker/serve/utils/conda_in_process.yml new file mode 100644 index 0000000000..6379812840 --- /dev/null +++ b/src/sagemaker/serve/utils/conda_in_process.yml @@ -0,0 +1,113 @@ +name: conda_env +channels: + - defaults +dependencies: + - accelerate>=0.24.1,<=0.27.0 + - sagemaker_schema_inference_artifacts>=0.0.5 + - uvicorn>=0.30.1 + - fastapi>=0.111.0 + - nest-asyncio + - pip>=23.0.1 + - attrs>=23.1.0,<24 + - boto3>=1.34.142,<2.0 + - cloudpickle==2.2.1 + - google-pasta + - numpy>=1.9.0,<2.0 + - protobuf>=3.12,<5.0 + - smdebug_rulesconfig==1.0.1 + - importlib-metadata>=1.4.0,<7.0 + - packaging>=20.0 + - pandas + - pathos + - schema + - PyYAML~=6.0 + - jsonschema + - platformdirs + - tblib>=1.7.0,<4 + - urllib3>=1.26.8,<3.0.0 + - requests + - docker + - tqdm + - psutil + - pip: + - altair>=4.2.2 + - anyio>=3.6.2 + - awscli>=1.27.114 + - blinker>=1.6.2 + - botocore>=1.29.114 + - cachetools>=5.3.0 + - certifi==2022.12.7 + - charset-normalizer>=3.1.0 + - click>=8.1.3 + - cloudpickle>=2.2.1 + - colorama>=0.4.4 + - contextlib2>=21.6.0 + - decorator>=5.1.1 + - dill>=0.3.6 + - docutils>=0.16 + - entrypoints>=0.4 + - filelock>=3.11.0 + - gitdb>=4.0.10 + - gitpython>=3.1.31 + - gunicorn>=20.1.0 + - h11>=0.14.0 + - huggingface-hub>=0.13.4 + - idna>=3.4 + - importlib-metadata>=4.13.0 + - jinja2>=3.1.2 + - jmespath>=1.0.1 + - jsonschema>=4.17.3 + - markdown-it-py>=2.2.0 + - markupsafe>=2.1.2 + - mdurl>=0.1.2 + - mpmath>=1.3.0 + - multiprocess>=0.70.14 + - networkx>=3.1 + - packaging>=23.1 + - pandas>=1.5.3 + - pathos>=0.3.0 + - pillow>=9.5.0 + - platformdirs>=3.2.0 + - pox>=0.3.2 + - ppft>=1.7.6.6 + - protobuf>=3.20.3 + - protobuf3-to-dict>=0.1.5 + - pyarrow>=11.0.0 + - pyasn1>=0.4.8 + - pydantic>=1.10.7 + - pydeck>=0.8.1b0 + - pygments>=2.15.1 + - pympler>=1.0.1 + - pyrsistent>=0.19.3 + - python-dateutil>=2.8.2 + - pytz>=2023.3 + - pytz-deprecation-shim>=0.1.0.post0 + - pyyaml>=5.4.1 + - regex>=2023.3.23 + - requests>=2.28.2 + - rich>=13.3.4 + - rsa>=4.7.2 + - 
s3transfer>=0.6.0 + - sagemaker>=2.148.0 + - schema>=0.7.5 + - six>=1.16.0 + - smdebug-rulesconfig>=1.0.1 + - smmap==5.0.0 + - sniffio>=1.3.0 + - starlette>=0.26.1 + - streamlit>=1.21.0 + - sympy>=1.11.1 + - tblib>=1.7.0 + - tokenizers>=0.13.3 + - toml>=0.10.2 + - toolz>=0.12.0 + - torch>=2.0.0 + - tornado>=6.3 + - tqdm>=4.65.0 + - transformers>=4.28.1 + - typing-extensions>=4.5.0 + - tzdata>=2023.3 + - tzlocal>=4.3 + - urllib3>=1.26.15 + - validators>=0.20.0 + - zipp>=3.15.0 diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py index 72b9083072..30b22ba869 100644 --- a/src/sagemaker/serve/utils/exceptions.py +++ b/src/sagemaker/serve/utils/exceptions.py @@ -1,4 +1,4 @@ -"""Placeholder Docstring""" +"""Exceptions used across different model builder invocations""" from __future__ import absolute_import diff --git a/src/sagemaker/serve/utils/in_process_requirements.txt b/src/sagemaker/serve/utils/in_process_requirements.txt new file mode 100644 index 0000000000..cb7915d78b --- /dev/null +++ b/src/sagemaker/serve/utils/in_process_requirements.txt @@ -0,0 +1,85 @@ +altair>=4.2.2 +anyio>=3.6.2 +awscli>=1.27.114 +blinker>=1.6.2 +botocore>=1.29.114 +cachetools>=5.3.0 +certifi==2022.12.7 +charset-normalizer>=3.1.0 +click>=8.1.3 +cloudpickle>=2.2.1 +colorama>=0.4.4 +contextlib2>=21.6.0 +decorator>=5.1.1 +dill>=0.3.6 +docutils>=0.16 +entrypoints>=0.4 +filelock>=3.11.0 +gitdb>=4.0.10 +gitpython>=3.1.31 +gunicorn>=20.1.0 +h11>=0.14.0 +huggingface-hub>=0.13.4 +idna>=3.4 +importlib-metadata>=4.13.0 +jinja2>=3.1.2 +jmespath>=1.0.1 +jsonschema>=4.17.3 +markdown-it-py>=2.2.0 +markupsafe>=2.1.2 +mdurl>=0.1.2 +mpmath>=1.3.0 +multiprocess>=0.70.14 +networkx>=3.1 +packaging>=23.1 +pandas>=1.5.3 +pathos>=0.3.0 +pillow>=9.5.0 +platformdirs>=3.2.0 +pox>=0.3.2 +ppft>=1.7.6.6 +protobuf>=3.20.3 +protobuf3-to-dict>=0.1.5 +pyarrow>=11.0.0 +pyasn1>=0.4.8 +pydantic>=1.10.7 +pydeck>=0.8.1b0 +pygments>=2.15.1 +pympler>=1.0.1 +pyrsistent>=0.19.3 
+python-dateutil>=2.8.2 +pytz>=2023.3 +pytz-deprecation-shim>=0.1.0.post0 +pyyaml>=5.4.1 +regex>=2023.3.23 +requests>=2.28.2 +rich>=13.3.4 +rsa>=4.7.2 +s3transfer>=0.6.0 +sagemaker>=2.148.0 +schema>=0.7.5 +six>=1.16.0 +smdebug-rulesconfig>=1.0.1 +smmap==5.0.0 +sniffio>=1.3.0 +starlette>=0.26.1 +streamlit>=1.21.0 +sympy>=1.11.1 +tblib>=1.7.0 +tokenizers>=0.13.3 +toml>=0.10.2 +toolz>=0.12.0 +torch>=2.0.0 +tornado>=6.3 +tqdm>=4.65.0 +transformers>=4.28.1 +typing-extensions>=4.5.0 +tzdata>=2023.3 +tzlocal>=4.3 +urllib3>=1.26.15 +validators>=0.20.0 +zipp>=3.15.0 +uvicorn>=0.30.1 +fastapi>=0.111.0 +nest-asyncio +transformers diff --git a/tests/unit/sagemaker/serve/builder/test_requirements_manager.py b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py new file mode 100644 index 0000000000..02833c81c0 --- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py @@ -0,0 +1,82 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+import unittest
+from unittest.mock import patch, call
+
+from sagemaker.serve.builder.requirements_manager import RequirementsManager
+
+
+class TestRequirementsManager(unittest.TestCase):
+    """Unit tests for RequirementsManager file-type dispatch and install commands."""
+
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._update_conda_env_in_path"
+    )
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._install_requirements_txt"
+    )
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
+    )
+    def test_capture_and_install_dependencies(
+        self,
+        mock_detect_conda_env_and_local_dependencies,
+        mock_install_requirements_txt,
+        mock_update_conda_env_in_path,
+    ) -> None:
+        """A detected .txt path runs the pip installer; a detected .yml path runs conda."""
+        mock_detect_conda_env_and_local_dependencies.return_value = ".txt"
+        RequirementsManager().capture_and_install_dependencies()
+        mock_install_requirements_txt.assert_called_once()
+        # The conda updater must stay untouched for a requirements file.
+        mock_update_conda_env_in_path.assert_not_called()
+
+        mock_detect_conda_env_and_local_dependencies.return_value = ".yml"
+        RequirementsManager().capture_and_install_dependencies()
+        mock_update_conda_env_in_path.assert_called_once()
+        # Still exactly one pip call: the .yml run must not add another.
+        mock_install_requirements_txt.assert_called_once()
+
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
+    )
+    def test_capture_and_install_dependencies_fail(
+        self, mock_detect_conda_env_and_local_dependencies
+    ) -> None:
+        """An explicit path with an unsupported extension raises ValueError."""
+        mock_dependencies = "mock.ini"
+        with self.assertRaises(ValueError):
+            RequirementsManager().capture_and_install_dependencies(mock_dependencies)
+        # An explicit path is used as given, so detection is never consulted.
+        mock_detect_conda_env_and_local_dependencies.assert_not_called()
+
+    @patch("sagemaker.serve.builder.requirements_manager.logger")
+    @patch("sagemaker.serve.builder.requirements_manager.subprocess")
+    def test_install_requirements_txt(self, mock_subprocess, mock_logger) -> None:
+        """With no path given, the default pip command is run and both log lines are emitted."""
+        RequirementsManager()._install_requirements_txt()
+
+        expected_logs = [call("Running command to pip install"), call("Command ran successfully")]
+        mock_logger.info.assert_has_calls(expected_logs)
+        mock_subprocess.run.assert_called_once_with(
+            "pip install -r in_process_requirements.txt", shell=True, check=True
+        )
+
+    @patch("sagemaker.serve.builder.requirements_manager.logger")
+    @patch("sagemaker.serve.builder.requirements_manager.subprocess")
+    def test_update_conda_env_in_path(self, mock_subprocess, mock_logger) -> None:
+        """With no path given, the default conda command is run and both log lines are emitted."""
+        RequirementsManager()._update_conda_env_in_path()
+
+        expected_logs = [call("Updating conda env"), call("Conda env updated successfully")]
+        mock_logger.info.assert_has_calls(expected_logs)
+        mock_subprocess.run.assert_called_once_with(
+            "conda env update -f conda_in_process.yml", shell=True, check=True
+        )