Skip to content

Commit a0f43b3

Browse files
author
Bryannah Hernandez
committed
minor changes
1 parent 84f1808 commit a0f43b3

File tree

4 files changed

+131
-118
lines changed

4 files changed

+131
-118
lines changed

src/sagemaker/serve/app.py

Lines changed: 74 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,100 @@
11
"""FastAPI requests"""
22

33
from __future__ import absolute_import
4+
5+
import asyncio
46
import logging
7+
import threading
8+
from typing import Optional
59

610

711
logger = logging.getLogger(__name__)
812

913

1014
try:
1115
import uvicorn
12-
1316
except ImportError:
14-
logger.error("To enable in_process mode for Transformers install uvicorn from HuggingFace hub")
17+
logger.error("Unable to import uvicorn, check if uvicorn is installed.")
1518

1619

1720
try:
1821
from transformers import pipeline
19-
20-
generator = pipeline("text-generation", model="gpt2")
21-
2222
except ImportError:
23-
logger.error(
24-
"To enable in_process mode for Transformers install transformers from HuggingFace hub"
25-
)
23+
logger.error("Unable to import transformers, check if transformers is installed.")
2624

2725

2826
try:
29-
from fastapi import FastAPI, Request
30-
31-
app = FastAPI(
32-
title="Transformers In Process Server",
33-
version="1.0",
34-
description="A simple server",
35-
)
36-
37-
@app.get("/")
38-
def read_root():
27+
from fastapi import FastAPI, Request, APIRouter
28+
except ImportError:
29+
logger.error("Unable to import fastapi, check if fastapi is installed.")
30+
31+
32+
class InProcessServer:
33+
"""Placeholder docstring"""
34+
35+
def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None):
36+
self._thread = None
37+
self._loop = None
38+
self._stop_event = asyncio.Event()
39+
self._router = APIRouter()
40+
self._model_id = model_id
41+
self._task = task
42+
self.server = None
43+
self.port = None
44+
self.host = None
45+
# TODO: Pick up device automatically.
46+
self._generator = pipeline(task, model=model_id, device="cpu")
47+
48+
# pylint: disable=unused-variable
49+
@self._router.post("/generate")
50+
async def generate_text(prompt: Request):
51+
"""Placeholder docstring"""
52+
str_prompt = await prompt.json()
53+
str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt
54+
55+
generated_text = self._generator(
56+
str_prompt, max_length=30, num_return_sequences=1, truncation=True
57+
)
58+
return generated_text
59+
60+
self._create_server()
61+
62+
def _create_server(self):
3963
"""Placeholder docstring"""
40-
return {"Hello": "World"}
64+
app = FastAPI()
65+
app.include_router(self._router)
66+
67+
config = uvicorn.Config(
68+
app,
69+
host="127.0.0.1",
70+
port=9007,
71+
log_level="info",
72+
loop="asyncio",
73+
reload=True,
74+
use_colors=True,
75+
)
4176

42-
@app.get("/generate")
43-
async def generate_text(prompt: Request):
44-
"""Placeholder docstring"""
45-
str_prompt = await prompt.json()
77+
self.server = uvicorn.Server(config)
78+
self.host = config.host
79+
self.port = config.port
4680

47-
generated_text = generator(
48-
str_prompt, max_length=30, num_return_sequences=5, truncation=True
49-
)
50-
return generated_text[0]["generated_text"]
81+
def start_server(self):
82+
"""Starts the uvicorn server."""
83+
if not (self._thread and self._thread.is_alive()):
84+
logger.info("Waiting for a connection...")
85+
self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
86+
self._thread.start()
5187

52-
@app.post("/post")
53-
def post(payload: dict):
88+
def stop_server(self):
89+
"""Destroys the uvicorn server."""
90+
# TODO: Implement me.
91+
92+
def _start_run_async_in_thread(self):
5493
"""Placeholder docstring"""
55-
return payload
94+
loop = asyncio.new_event_loop()
95+
asyncio.set_event_loop(loop)
96+
loop.run_until_complete(self._serve())
5697

57-
except ImportError:
58-
logger.error("To enable in_process mode for Transformers install fastapi from HuggingFace hub")
59-
60-
61-
async def main():
62-
"""Running server locally with uvicorn"""
63-
config = uvicorn.Config(
64-
"sagemaker.serve.app:app",
65-
host="127.0.0.1",
66-
port=9007,
67-
log_level="info",
68-
loop="asyncio",
69-
reload=True,
70-
workers=3,
71-
use_colors=True,
72-
)
73-
server = uvicorn.Server(config)
74-
logger.info("Waiting for a connection...")
75-
await server.serve()
98+
async def _serve(self):
99+
"""Placeholder docstring"""
100+
await self.server.serve()

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module that defines the InProcessMode class"""
22

33
from __future__ import absolute_import
4+
45
from pathlib import Path
56
import logging
67
from typing import Dict, Type
@@ -11,7 +12,7 @@
1112
from sagemaker.serve.spec.inference_spec import InferenceSpec
1213
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1314
from sagemaker.serve.utils.types import ModelServer
14-
from sagemaker.serve.utils.exceptions import LocalDeepPingException
15+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
1516
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
1617
from sagemaker.session import Session
1718

@@ -46,7 +47,7 @@ def __init__(
4647
self.session = session
4748
self.schema_builder = schema_builder
4849
self.model_server = model_server
49-
self._ping_container = None
50+
self._ping_local_server = None
5051

5152
def load(self, model_path: str = None):
5253
"""Loads model path, checks that path exists"""
@@ -69,22 +70,29 @@ def create_server(
6970
logger.info("Waiting for model server %s to start up...", self.model_server)
7071

7172
if self.model_server == ModelServer.MMS:
72-
self._ping_container = self._multi_model_server_deep_ping
73+
self._ping_local_server = self._multi_model_server_deep_ping
7374
self._start_serving()
7475

76+
# allow some time for server to be ready.
77+
time.sleep(1)
78+
7579
time_limit = datetime.now() + timedelta(seconds=5)
76-
while self._ping_container is not None:
80+
healthy = True
81+
while True:
7782
final_pull = datetime.now() > time_limit
78-
7983
if final_pull:
8084
break
8185

82-
time.sleep(10)
83-
84-
healthy, response = self._ping_container(predictor)
86+
healthy, response = self._ping_local_server(predictor)
8587
if healthy:
8688
logger.debug("Ping health check has passed. Returned %s", str(response))
8789
break
8890

91+
time.sleep(1)
92+
8993
if not healthy:
90-
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
94+
raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
95+
96+
def destroy_server(self):
97+
"""Placeholder docstring"""
98+
self._stop_serving()

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from __future__ import absolute_import
44

5-
import asyncio
5+
import json
6+
67
import requests
78
import logging
89
import platform
9-
import time
1010
from pathlib import Path
11+
1112
from sagemaker import Session, fw_utils
1213
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
1314
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
@@ -26,78 +27,57 @@
2627
class InProcessMultiModelServer:
2728
"""In Process Mode Multi Model server instance"""
2829

29-
def __init__(self):
30-
from sagemaker.serve.app import main
31-
32-
self._main = main
33-
3430
def _start_serving(self):
3531
"""Initializes the start of the server"""
36-
background_tasks = set()
37-
task = asyncio.create_task(self._main())
38-
background_tasks.add(task)
39-
task.add_done_callback(background_tasks.discard)
32+
from sagemaker.serve.app import InProcessServer
4033

41-
time.sleep(10)
34+
if hasattr(self, "inference_spec"):
35+
model_id = self.inference_spec.get_model()
36+
if not model_id:
37+
raise ValueError("Model id was not provided in Inference Spec.")
38+
else:
39+
model_id = None
40+
self.server = InProcessServer(model_id=model_id)
4241

43-
def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
42+
self.server.start_server()
43+
44+
def _stop_serving(self):
45+
"""Stops the server"""
46+
self.server.stop_server()
47+
48+
def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
4449
"""Placeholder docstring"""
45-
time.sleep(2)
46-
background_tasks = set()
47-
task = asyncio.create_task(self.generate_connect())
48-
background_tasks.add(task)
49-
task.add_done_callback(background_tasks.discard)
50-
return task.result()
50+
try:
51+
response = requests.post(
52+
f"http://{self.server.host}:{self.server.port}/generate",
53+
data=request,
54+
headers={"Content-Type": content_type, "Accept": accept},
55+
timeout=600,
56+
)
57+
response.raise_for_status()
58+
if isinstance(response.content, bytes):
59+
return json.loads(response.content.decode("utf-8"))
60+
return response.content
61+
except Exception as e:
62+
if "Connection refused" in str(e):
63+
raise Exception(
64+
"Unable to send request to the local server: Connection refused."
65+
) from e
66+
raise Exception("Unable to send request to the local server.") from e
5167

5268
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
5369
"""Sends a deep ping to ensure prediction"""
54-
background_tasks = set()
55-
task = asyncio.create_task(self.tcp_connect())
56-
background_tasks.add(task)
57-
task.add_done_callback(background_tasks.discard)
70+
healthy = False
5871
response = None
5972
try:
6073
response = predictor.predict(self.schema_builder.sample_input)
61-
return True, response
74+
healthy = response is not None
6275
# pylint: disable=broad-except
6376
except Exception as e:
6477
if "422 Client Error: Unprocessable Entity for url" in str(e):
6578
raise InProcessDeepPingException(str(e))
66-
return False, response
67-
68-
async def generate_connect(self):
69-
"""Writes the lines in bytes for server"""
70-
reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
71-
a = (
72-
b"GET /generate HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: "
73-
b"python-requests/2.31.0\r\nAccept-Encoding: gzip, deflate, br\r\nAccept: */*\r\nConnection: ",
74-
"keep-alive\r\nContent-Length: 33\r\nContent-Type: application/json\r\n\r\n",
75-
)
76-
b = b'"\\"Hello, I\'m a language model\\""'
77-
list = [a, b]
78-
writer.writelines(list)
79-
80-
data = await reader.read()
81-
logger.info("Response from server")
82-
logger.info(data)
83-
writer.close()
84-
await writer.wait_closed()
85-
return data
86-
87-
async def tcp_connect(self):
88-
"""Writes the lines in bytes for server"""
89-
reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
90-
writer.write(
91-
b"GET / HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: python-requests/2.32.3\r\nAccept-Encoding: gzip, ",
92-
"deflate, br\r\nAccept: */*\r\nConnection: keep-alive\r\n\r\n",
93-
)
9479

95-
data = await reader.read()
96-
logger.info("Response from server")
97-
logger.info(data)
98-
writer.close()
99-
await writer.wait_closed()
100-
return data
80+
return healthy, response
10181

10282

10383
class LocalMultiModelServer:

tests/unit/sagemaker/serve/mode/test_in_process_mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from sagemaker.serve.mode.in_process_mode import InProcessMode
1919
from sagemaker.serve import SchemaBuilder
2020
from sagemaker.serve.utils.types import ModelServer
21-
from sagemaker.serve.utils.exceptions import LocalDeepPingException
21+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
2222

2323

2424
mock_prompt = "Hello, I'm a language model,"
@@ -163,4 +163,4 @@ def test_create_server_ex(
163163
in_process_mode._multi_model_server_deep_ping = mock_multi_model_server_deep_ping
164164
in_process_mode._start_serving = mock_start_serving
165165

166-
self.assertRaises(LocalDeepPingException, in_process_mode.create_server, mock_predictor)
166+
self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor)

0 commit comments

Comments
 (0)