Skip to content

Commit e89940e

Browse files
author
Jonathan Makunga
committed
In Process Mode
1 parent 7fb02cf commit e89940e

File tree

3 files changed

+153
-114
lines changed

3 files changed

+153
-114
lines changed

src/sagemaker/serve/app.py

Lines changed: 85 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,109 @@
11
"""FastAPI requests"""
22

33
from __future__ import absolute_import
4+
5+
import asyncio
46
import logging
7+
import threading
8+
from typing import Optional
59

610

711
logger = logging.getLogger(__name__)
812

913

1014
try:
1115
import uvicorn
12-
1316
except ImportError:
14-
logger.error("To enable in_process mode for Transformers install uvicorn from HuggingFace hub")
17+
logger.error("Unable to import uvicorn, check if uvicorn is installed.")
1518

1619

1720
try:
1821
from transformers import pipeline
19-
20-
generator = pipeline("text-generation", model="gpt2")
21-
2222
except ImportError:
2323
logger.error(
24-
"To enable in_process mode for Transformers install transformers from HuggingFace hub"
24+
"Unable to import transformers, check if transformers is installed."
2525
)
2626

2727

2828
try:
29-
from fastapi import FastAPI, Request
30-
31-
app = FastAPI(
32-
title="Transformers In Process Server",
33-
version="1.0",
34-
description="A simple server",
35-
)
36-
37-
@app.get("/")
38-
def read_root():
39-
"""Placeholder docstring"""
40-
return {"Hello": "World"}
41-
42-
@app.get("/generate")
43-
async def generate_text(prompt: Request):
44-
"""Placeholder docstring"""
45-
str_prompt = await prompt.json()
46-
47-
generated_text = generator(
48-
str_prompt, max_length=30, num_return_sequences=5, truncation=True
29+
from fastapi import FastAPI, Request, APIRouter
30+
except ImportError:
31+
logger.error("Unable to import fastapi, check if fastapi is installed.")
32+
33+
34+
class InProcessServer:
    """Hosts a HuggingFace ``transformers`` pipeline behind a local FastAPI/uvicorn server.

    The uvicorn server runs on a daemon thread inside the current process and
    exposes a single ``POST /generate`` endpoint backed by
    ``transformers.pipeline``.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        task: Optional[str] = None,
    ):
        """Build the pipeline, the router and the (not yet started) uvicorn server.

        Args:
            model_id: HuggingFace model id to load. ``None`` lets ``pipeline``
                pick its default model for the task.
            task: pipeline task name. ``None`` lets ``pipeline`` infer the task
                from the model.
        """
        self._thread = None
        self._loop = None
        self._stop_event = asyncio.Event()
        self._router = APIRouter()
        self._model_id = model_id
        self._task = task
        self.server = None
        self.port = None
        self.host = None

        # CPU on purpose: this is an in-process convenience server, not a
        # production deployment.
        self._generator = pipeline(task, model=model_id, device="cpu")

        @self._router.post("/generate")
        async def generate_text(prompt: Request):
            """Generate text for the JSON body of the request."""
            str_prompt = await prompt.json()
            # Accept either a bare prompt or an {"inputs": ...} envelope.
            # NOTE(review): when the body is a plain string, `"inputs" in`
            # is a substring test -- presumably payloads are dicts; confirm
            # against callers.
            str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt

            generated_text = self._generator(
                str_prompt, max_length=30, num_return_sequences=1, truncation=True
            )
            return generated_text

        self._create_server()

    def _create_server(self):
        """Create the FastAPI app and the uvicorn server bound to 127.0.0.1:9007."""
        _app = FastAPI()
        _app.include_router(self._router)

        # NOTE(review): `reload` and `workers` only take effect when uvicorn is
        # given an import string; with an app object and Server.serve() they
        # are effectively inert.
        config = uvicorn.Config(
            _app,
            host="127.0.0.1",
            port=9007,
            log_level="info",
            loop="asyncio",
            reload=True,
            workers=3,
            use_colors=True,
        )

        self.server = uvicorn.Server(config)
        self.host = config.host
        self.port = config.port

    def start_server(self):
        """Starts the uvicorn server on a daemon thread (no-op if already running)."""
        if not (self._thread and self._thread.is_alive()):
            logger.info("Waiting for a connection...")
            self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
            self._thread.start()

    def stop_server(self):
        """Stops the uvicorn server and waits for its thread to finish."""
        # BUG FIX: the previous check `if self.is_running:` tested the bound
        # method object (always truthy) instead of calling it, and the actual
        # shutdown was commented out, so the server was never stopped.
        if self.is_running():
            logger.info("Deleting server...")
            # `should_exit` is uvicorn's documented graceful-shutdown flag:
            # Server.serve() returns once it is set.
            self.server.should_exit = True
            self._thread.join(timeout=10)
            logger.info("Server deleted.")

    def _start_run_async_in_thread(self):
        """Thread target: run the server's serve() coroutine on a fresh event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self._serve())

    async def _serve(self):
        """Await uvicorn's serve loop until shutdown is requested."""
        await self.server.serve()

    def is_running(self):
        """Return True while the server thread is alive."""
        return self._thread is not None and self._thread.is_alive()

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module that defines the InProcessMode class"""
22

33
from __future__ import absolute_import
4+
45
from pathlib import Path
56
import logging
67
from typing import Dict, Type
@@ -11,7 +12,7 @@
1112
from sagemaker.serve.spec.inference_spec import InferenceSpec
1213
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1314
from sagemaker.serve.utils.types import ModelServer
14-
from sagemaker.serve.utils.exceptions import LocalDeepPingException
15+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
1516
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
1617
from sagemaker.session import Session
1718

@@ -46,7 +47,7 @@ def __init__(
4647
self.session = session
4748
self.schema_builder = schema_builder
4849
self.model_server = model_server
49-
self._ping_container = None
50+
self._ping_local_server = None
5051

5152
def load(self, model_path: str = None):
5253
"""Loads model path, checks that path exists"""
@@ -69,22 +70,30 @@ def create_server(
6970
logger.info("Waiting for model server %s to start up...", self.model_server)
7071

7172
if self.model_server == ModelServer.MMS:
72-
self._ping_container = self._multi_model_server_deep_ping
73+
self._ping_local_server = self._multi_model_server_deep_ping
7374
self._start_serving()
7475

75-
time_limit = datetime.now() + timedelta(seconds=5)
76+
# allow some time for server to be ready.
77+
time.sleep(1)
78+
79+
count = 1
80+
time_limit = datetime.now() + timedelta(seconds=20)
81+
healthy = True
7682
while True:
7783
final_pull = datetime.now() > time_limit
78-
7984
if final_pull:
8085
break
8186

82-
time.sleep(10)
83-
84-
healthy, response = self._ping_container(predictor)
87+
healthy, response = self._ping_local_server(predictor)
88+
count += 1
8589
if healthy:
8690
logger.debug("Ping health check has passed. Returned %s", str(response))
8791
break
8892

93+
time.sleep(1)
94+
8995
if not healthy:
90-
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
96+
raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
97+
98+
def destroy_server(self):
    """Tear down the in-process model server by delegating to ``_stop_serving``."""
    self._stop_serving()

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
from __future__ import absolute_import
44

55
import asyncio
6+
import json
7+
import threading
8+
69
import requests
710
import logging
811
import platform
9-
import time
1012
from pathlib import Path
13+
1114
from sagemaker import Session, fw_utils
1215
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
1316
from sagemaker.base_predictor import PredictorBase
@@ -26,70 +29,63 @@ class InProcessMultiModelServer:
2629
"""In Process Mode Multi Model server instance"""
2730

2831
def __init__(self):
    """No eager setup: the ``InProcessServer`` is created lazily in ``_start_serving``."""
    # Dead commented-out import/thread code removed; construction is deferred
    # so importing this module does not require the FastAPI/uvicorn stack.
3240

3341
def _start_serving(self):
    """Create an ``InProcessServer`` for this model and start it on a background thread."""
    # Imported lazily so the FastAPI/uvicorn stack is only required when
    # in-process serving is actually used.
    from sagemaker.serve.app import InProcessServer

    # NOTE(review): assumes inference_spec.get_model() returns a model id
    # usable by transformers.pipeline -- confirm against InferenceSpec.
    if hasattr(self, "inference_spec"):
        model_id = self.inference_spec.get_model()
    else:
        model_id = None

    self.server = InProcessServer(model_id=model_id)
    self.server.start_server()
4152

42-
def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
53+
def _stop_serving(self):
    """Stops the server"""
    self.server.stop_server()
56+
57+
def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
    """POST the serialized request to the local server's /generate endpoint.

    Args:
        request: serialized payload forwarded as the HTTP body.
        content_type: value for the ``Content-Type`` header.
        accept: value for the ``Accept`` header.

    Returns:
        The JSON-decoded response body (or the raw body if it is not bytes).

    Raises:
        Exception: when the local server is unreachable or returns an
            HTTP error status.
    """
    try:
        response = requests.post(
            f"http://{self.server.host}:{self.server.port}/generate",
            data=request,
            headers={"Content-Type": content_type, "Accept": accept},
            timeout=600,
        )
        response.raise_for_status()
        if isinstance(response.content, bytes):
            return json.loads(response.content.decode("utf-8"))
        return response.content
    except Exception as e:
        # BUG FIX: the two messages were swapped -- the "Connection refused"
        # message was previously raised for every error EXCEPT connection
        # refusals (the condition was `if not "Connection refused" in str(e)`).
        if "Connection refused" in str(e):
            raise Exception(
                "Unable to send request to the local server: Connection refused."
            ) from e
        raise Exception("Unable to send request to the local server.") from e
4875

4976
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
    """Probe the local server by running one sample prediction.

    Args:
        predictor: predictor used to invoke the server with
            ``self.schema_builder.sample_input``.

    Returns:
        Tuple ``(healthy, response)`` -- ``healthy`` is True when a non-None
        response came back; ``response`` is the raw prediction or None.

    Raises:
        LocalModelInvocationException: when the server rejects the payload
            with a 422 Unprocessable Entity error.
    """
    healthy = False
    response = None
    try:
        response = predictor.predict(self.schema_builder.sample_input)
        healthy = response is not None
    # Deliberate best-effort ping: any failure other than a 422 simply
    # reports unhealthy instead of raising.
    # pylint: disable=broad-except
    except Exception as e:
        if "422 Client Error: Unprocessable Entity for url" in str(e):
            # Chain the original exception for easier debugging.
            raise LocalModelInvocationException(str(e)) from e

    return healthy, response
9389

9490

9591
class LocalMultiModelServer:

0 commit comments

Comments
 (0)