diff --git a/requirements/extras/huggingface_requirements.txt b/requirements/extras/huggingface_requirements.txt index c7ec458ea5..3ee6208618 100644 --- a/requirements/extras/huggingface_requirements.txt +++ b/requirements/extras/huggingface_requirements.txt @@ -1,2 +1,5 @@ accelerate>=0.24.1,<=0.27.0 sagemaker_schema_inference_artifacts>=0.0.5 +uvicorn>=0.30.1 +fastapi>=0.111.0 +nest-asyncio diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 71b4ad6256..bbe0061500 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -40,3 +40,6 @@ schema==0.7.5 tensorflow>=2.1,<=2.16 mlflow>=2.12.2,<2.13 huggingface_hub>=0.23.4 +uvicorn>=0.30.1 +fastapi>=0.111.0 +nest-asyncio diff --git a/src/sagemaker/serve/app.py b/src/sagemaker/serve/app.py new file mode 100644 index 0000000000..fd9dd6a93a --- /dev/null +++ b/src/sagemaker/serve/app.py @@ -0,0 +1,100 @@ +"""FastAPI requests""" + +from __future__ import absolute_import + +import asyncio +import logging +import threading +from typing import Optional + + +logger = logging.getLogger(__name__) + + +try: + import uvicorn +except ImportError: + logger.error("Unable to import uvicorn, check if uvicorn is installed.") + + +try: + from transformers import pipeline +except ImportError: + logger.error("Unable to import transformers, check if transformers is installed.") + + +try: + from fastapi import FastAPI, Request, APIRouter +except ImportError: + logger.error("Unable to import fastapi, check if fastapi is installed.") + + +class InProcessServer: + """Placeholder docstring""" + + def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None): + self._thread = None + self._loop = None + self._stop_event = asyncio.Event() + self._router = APIRouter() + self._model_id = model_id + self._task = task + self.server = None + self.port = None + self.host = None + # TODO: Pick up device automatically. 
+ self._generator = pipeline(task, model=model_id, device="cpu") + + # pylint: disable=unused-variable + @self._router.post("/generate") + async def generate_text(prompt: Request): + """Placeholder docstring""" + str_prompt = await prompt.json() + str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt + + generated_text = self._generator( + str_prompt, max_length=30, num_return_sequences=1, truncation=True + ) + return generated_text + + self._create_server() + + def _create_server(self): + """Placeholder docstring""" + app = FastAPI() + app.include_router(self._router) + + config = uvicorn.Config( + app, + host="127.0.0.1", + port=9007, + log_level="info", + loop="asyncio", + reload=True, + use_colors=True, + ) + + self.server = uvicorn.Server(config) + self.host = config.host + self.port = config.port + + def start_server(self): + """Starts the uvicorn server.""" + if not (self._thread and self._thread.is_alive()): + logger.info("Waiting for a connection...") + self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True) + self._thread.start() + + def stop_server(self): + """Destroys the uvicorn server.""" + # TODO: Implement me. + + def _start_run_async_in_thread(self): + """Placeholder docstring""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._serve()) + + async def _serve(self): + """Placeholder docstring""" + await self.server.serve() diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 1431965317..a919aa7342 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -812,7 +812,7 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None: self.dependencies.update({"requirements": mlflow_model_dependency_path}) # Model Builder is a class to build the model for deployment. 
- # It supports two* modes of deployment + # It supports three modes of deployment # 1/ SageMaker Endpoint # 2/ Local launch with container # 3/ In process mode with Transformers server in beta release diff --git a/src/sagemaker/serve/builder/requirements_manager.py b/src/sagemaker/serve/builder/requirements_manager.py index 7497bd471a..a8b41dba40 100644 --- a/src/sagemaker/serve/builder/requirements_manager.py +++ b/src/sagemaker/serve/builder/requirements_manager.py @@ -36,7 +36,7 @@ def capture_and_install_dependencies(self, dependencies: Optional[str] = None) - Returns: file path of the existing or generated dependencies file """ - _dependencies = dependencies or self._detect_conda_env_and_local_dependencies() + _dependencies = dependencies or self._detect_conda_env_and_local_dependencies() # Dependencies specified as either req.txt or conda_env.yml if _dependencies.endswith(".txt"): diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index e3f1f15cf7..b380dc8455 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -421,6 +421,6 @@ def _create_conda_env(self): """Creating conda environment by running commands""" try: - RequirementsManager().capture_and_install_dependencies(self) + RequirementsManager().capture_and_install_dependencies() except subprocess.CalledProcessError: print("Failed to create and activate conda environment.") diff --git a/src/sagemaker/serve/mode/in_process_mode.py b/src/sagemaker/serve/mode/in_process_mode.py index dc3b4fd74f..60d4f91e34 100644 --- a/src/sagemaker/serve/mode/in_process_mode.py +++ b/src/sagemaker/serve/mode/in_process_mode.py @@ -1,6 +1,7 @@ """Module that defines the InProcessMode class""" from __future__ import absolute_import + from pathlib import Path import logging from typing import Dict, Type @@ -11,7 +12,7 @@ from sagemaker.serve.spec.inference_spec import InferenceSpec from 
sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.utils.types import ModelServer -from sagemaker.serve.utils.exceptions import LocalDeepPingException +from sagemaker.serve.utils.exceptions import InProcessDeepPingException from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer from sagemaker.session import Session @@ -46,7 +47,7 @@ def __init__( self.session = session self.schema_builder = schema_builder self.model_server = model_server - self._ping_container = None + self._ping_local_server = None def load(self, model_path: str = None): """Loads model path, checks that path exists""" @@ -69,21 +70,29 @@ def create_server( logger.info("Waiting for model server %s to start up...", self.model_server) if self.model_server == ModelServer.MMS: - self._ping_container = self._multi_model_server_deep_ping + self._ping_local_server = self._multi_model_server_deep_ping + self._start_serving() + + # allow some time for server to be ready. + time.sleep(1) time_limit = datetime.now() + timedelta(seconds=5) - while self._ping_container is not None: + healthy = True + while True: final_pull = datetime.now() > time_limit - if final_pull: break - time.sleep(10) - - healthy, response = self._ping_container(predictor) + healthy, response = self._ping_local_server(predictor) if healthy: logger.debug("Ping health check has passed. 
Returned %s", str(response)) break + time.sleep(1) + if not healthy: - raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) + raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) + + def destroy_server(self): + """Placeholder docstring""" + self._stop_serving() diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index b957186b99..69d5d2e5e7 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -2,12 +2,16 @@ from __future__ import absolute_import +import json + import requests import logging import platform from pathlib import Path + from sagemaker import Session, fw_utils from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.serve.utils.exceptions import InProcessDeepPingException from sagemaker.base_predictor import PredictorBase from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader @@ -25,16 +29,55 @@ class InProcessMultiModelServer: def _start_serving(self): """Initializes the start of the server""" - return Exception("Not implemented") + from sagemaker.serve.app import InProcessServer - def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str): - """Invokes the MMS server by sending POST request""" - return Exception("Not implemented") + if hasattr(self, "inference_spec"): + model_id = self.inference_spec.get_model() + if not model_id: + raise ValueError("Model id was not provided in Inference Spec.") + else: + model_id = None + self.server = InProcessServer(model_id=model_id) + + self.server.start_server() + + def _stop_serving(self): + """Stops the server""" + self.server.stop_server() + + def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str): + """Placeholder docstring""" + try: + response 
= requests.post( + f"http://{self.server.host}:{self.server.port}/generate", + data=request, + headers={"Content-Type": content_type, "Accept": accept}, + timeout=600, + ) + response.raise_for_status() + if isinstance(response.content, bytes): + return json.loads(response.content.decode("utf-8")) + return response.content + except Exception as e: + if "Connection refused" in str(e): + raise Exception( + "Unable to send request to the local server: Connection refused." + ) from e + raise Exception("Unable to send request to the local server.") from e def _multi_model_server_deep_ping(self, predictor: PredictorBase): """Sends a deep ping to ensure prediction""" + healthy = False response = None - return (True, response) + try: + response = predictor.predict(self.schema_builder.sample_input) + healthy = response is not None + # pylint: disable=broad-except + except Exception as e: + if "422 Client Error: Unprocessable Entity for url" in str(e): + raise InProcessDeepPingException(str(e)) + + return healthy, response class LocalMultiModelServer: diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py index be6133e8e1..89ec2253f1 100644 --- a/src/sagemaker/serve/utils/predictors.py +++ b/src/sagemaker/serve/utils/predictors.py @@ -3,7 +3,7 @@ from __future__ import absolute_import import io from typing import Type - +import logging from sagemaker import Session from sagemaker.serve.mode.local_container_mode import LocalContainerMode from sagemaker.serve.mode.in_process_mode import InProcessMode @@ -16,6 +16,8 @@ APPLICATION_X_NPY = "application/x-npy" +logger = logging.getLogger(__name__) + class TorchServeLocalPredictor(PredictorBase): """Lightweight predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes""" @@ -211,7 +213,7 @@ def delete_predictor(self): class TransformersInProcessModePredictor(PredictorBase): - """Lightweight Transformers predictor for local deployment""" + """Lightweight Transformers predictor for 
in process mode deployment""" def __init__( self, @@ -225,18 +227,11 @@ def __init__( def predict(self, data): """Placeholder docstring""" - return [ - self.deserializer.deserialize( - io.BytesIO( - self._mode_obj._invoke_multi_model_server_serving( - self.serializer.serialize(data), - self.content_type, - self.deserializer.ACCEPT[0], - ) - ), - self.content_type, - ) - ] + return self._mode_obj._invoke_multi_model_server_serving( + self.serializer.serialize(data), + self.content_type, + self.deserializer.ACCEPT[0], + ) @property def content_type(self): diff --git a/tests/unit/sagemaker/serve/builder/test_requirements_manager.py b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py index 02833c81c0..b6886ab0a6 100644 --- a/tests/unit/sagemaker/serve/builder/test_requirements_manager.py +++ b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py @@ -29,7 +29,7 @@ class TestRequirementsManager(unittest.TestCase): @patch( "sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies" ) - def test_capture_and_install_dependencies( + def test_capture_and_install_dependencies_txt( self, mock_detect_conda_env_and_local_dependencies, mock_install_requirements_txt, @@ -40,8 +40,7 @@ def test_capture_and_install_dependencies( RequirementsManager().capture_and_install_dependencies() mock_install_requirements_txt.assert_called_once() - mock_detect_conda_env_and_local_dependencies.side_effect = lambda: ".yml" - RequirementsManager().capture_and_install_dependencies() + RequirementsManager().capture_and_install_dependencies("conda.yml") mock_update_conda_env_in_path.assert_called_once() @patch( diff --git a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py index f5890982b9..0b1747029a 100644 --- a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py +++ b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py @@ -18,7 +18,7 @@ from 
sagemaker.serve.mode.in_process_mode import InProcessMode from sagemaker.serve import SchemaBuilder from sagemaker.serve.utils.types import ModelServer -from sagemaker.serve.utils.exceptions import LocalDeepPingException +from sagemaker.serve.utils.exceptions import InProcessDeepPingException mock_prompt = "Hello, I'm a language model," @@ -98,6 +98,12 @@ def test_load_ex(self, mock_session, mock_inference_spec, mock_path): def test_create_server_happy( self, mock_session, mock_inference_spec, mock_predictor, mock_logger ): + mock_start_serving = Mock() + mock_start_serving.side_effect = lambda *args, **kwargs: ( + True, + None, + ) + mock_response = "Fake response" mock_multi_model_server_deep_ping = Mock() mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( @@ -114,6 +120,7 @@ def test_create_server_happy( ) in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping + in_process_mode._start_serving = mock_start_serving in_process_mode.create_server(predictor=mock_predictor) @@ -133,6 +140,12 @@ def test_create_server_ex( mock_inference_spec, mock_predictor, ): + mock_start_serving = Mock() + mock_start_serving.side_effect = lambda *args, **kwargs: ( + True, + None, + ) + mock_multi_model_server_deep_ping = Mock() mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( False, @@ -148,5 +161,29 @@ def test_create_server_ex( ) in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping + in_process_mode._start_serving = mock_start_serving + + self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor) + + @patch( + "sagemaker.serve.model_server.multi_model_server.server.InProcessMultiModelServer._stop_serving" + ) + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_destroy_server( + self, + mock_session, + mock_inference_spec, + mock_stop_serving, + ): + in_process_mode = InProcessMode( + 
model_server=ModelServer.MMS, + inference_spec=mock_inference_spec, + schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), + session=mock_session, + model_path="model_path", + ) + + in_process_mode.destroy_server() - self.assertRaises(LocalDeepPingException, in_process_mode.create_server, mock_predictor) + mock_stop_serving.assert_called() diff --git a/tests/unit/sagemaker/serve/test_app.py b/tests/unit/sagemaker/serve/test_app.py new file mode 100644 index 0000000000..5b9a2218bb --- /dev/null +++ b/tests/unit/sagemaker/serve/test_app.py @@ -0,0 +1,91 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest +import pytest + +from unittest.mock import patch, Mock +from sagemaker.serve.app import InProcessServer +from tests.integ.sagemaker.serve.constants import ( + PYTHON_VERSION_IS_NOT_310, +) + +mock_model_id = "mock_model_id" + + +class TestAppInProcessServer(unittest.TestCase): + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests are to test the serving components of our feature", + ) + @patch("sagemaker.serve.app.threading") + @patch("sagemaker.serve.app.pipeline") + def test_in_process_server_init(self, mock_pipeline, mock_threading): + mock_generator = Mock() + mock_generator.side_effect = None + + in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server._generator = mock_generator + + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these test are to test the serving components of our feature", + ) + @patch("sagemaker.serve.app.logger") + @patch("sagemaker.serve.app.threading") + @patch("sagemaker.serve.app.pipeline") + def test_start_server(self, mock_pipeline, mock_threading, mock_logger): + mock_generator = Mock() + mock_generator.side_effect = None + mock_thread = Mock() + mock_threading.Thread.return_value = mock_thread + + in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server._generator = mock_generator + + in_process_server.start_server() + + mock_logger.info.assert_called() + mock_thread.start.assert_called() + + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these test are to test the serving components of our feature", + ) + @patch("sagemaker.serve.app.asyncio") + @patch("sagemaker.serve.app.pipeline") + def test_start_run_async_in_thread(self, mock_pipeline, mock_asyncio): + mock_pipeline.side_effect = None + + mock_loop = Mock() + mock_asyncio.new_event_loop.side_effect = lambda: mock_loop + + in_process_server = InProcessServer(model_id=mock_model_id) + 
in_process_server._start_run_async_in_thread() + + mock_asyncio.set_event_loop.assert_called_once_with(mock_loop) + mock_loop.run_until_complete.assert_called() + + @patch("sagemaker.serve.app.pipeline") + def test_serve(self, mock_pipeline): + import asyncio + from unittest.mock import AsyncMock + + mock_pipeline.side_effect = None + + mock_server = Mock() + mock_server.serve = AsyncMock() + + in_process_server = InProcessServer(model_id=mock_model_id) + in_process_server.server = mock_server + + asyncio.run(in_process_server._serve()) + + mock_server.serve.assert_called()