
Commit 499063d

FastAPI with In_Process

Author: Bryannah Hernandez (committed)
1 parent 68000e1 commit 499063d

4 files changed: +85 -4 lines changed
Lines changed: 3 additions & 0 deletions

@@ -1,2 +1,5 @@
 accelerate>=0.24.1,<=0.27.0
 sagemaker_schema_inference_artifacts>=0.0.5
+uvicorn>=0.30.1
+fastapi>=0.111.0
+nest-asyncio
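Of the three new requirements, fastapi and uvicorn back the in-process server added below, while nest-asyncio is presumably needed because the server is scheduled onto an event loop that may already be running (for example, inside a Jupyter notebook). A minimal sketch of how nest_asyncio is typically applied, not part of this commit:

```python
import asyncio

import nest_asyncio

# nest_asyncio patches the current event loop so run_until_complete()
# can be called re-entrantly, e.g. from a notebook cell whose kernel
# already runs a loop.
nest_asyncio.apply()


async def probe():
    await asyncio.sleep(0)
    return "loop is usable"


# Without nest_asyncio this raises RuntimeError inside an already-running loop.
print(asyncio.get_event_loop().run_until_complete(probe()))
```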

src/sagemaker/app.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from __future__ import absolute_import
+
+import asyncio
+import logging
+from transformers import pipeline
+from fastapi import FastAPI
+import uvicorn
+
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="Transformers In Process Server",
+    version="1.0",
+    description="A simple server",
+)
+
+
+@app.get("/")
+def read_root():
+    return {"Hello": "World"}
+
+
+@app.post("/generate")
+def generate_text(prompt: str, max_length=500, num_return_sequences=1):
+    logger.info("Generating Text....")
+
+    generated_text = generator(
+        prompt, max_length=max_length, num_return_sequences=num_return_sequences
+    )
+    return generated_text[0]["generated_text"]
+
+
+generator = pipeline("text-generation", model="gpt2")
+
+
+@app.post("/post")
+def post(prompt: str):
+    return prompt
+
+
+async def main():
+    logger.info("Running")
+    config = uvicorn.Config(
+        "sagemaker.app:app", host="0.0.0.0", port=8080, log_level="info", loop="asyncio"
+    )
+    server = uvicorn.Server(config)
+    await server.serve()
+
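Because prompt is a bare str rather than a Pydantic model, FastAPI treats it as a query parameter, not the request body. A minimal client sketch (not part of this commit), assuming the server started by main() is already listening on port 8080:

```python
import requests

# Assumes sagemaker.app.main() is already serving on 0.0.0.0:8080.
# FastAPI maps the bare `prompt: str` parameter to the query string,
# so the prompt goes in `params`, not in the POST body.
response = requests.post(
    "http://0.0.0.0:8080/generate",
    params={"prompt": "Hello, world"},
    timeout=60,
)
response.raise_for_status()
print(response.json())  # the text generated by the gpt2 pipeline
```

Note that _invoke_multi_model_server_serving below sends the payload as the POST body (data=request) rather than as a query parameter.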

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 30 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 from __future__ import absolute_import
 
+import asyncio
 import requests
 import logging
 import platform
@@ -13,6 +14,8 @@
 from sagemaker.s3 import S3Uploader
 from sagemaker.local.utils import get_docker_host
 from sagemaker.serve.utils.optimize_utils import _is_s3_uri
+import time
+from sagemaker.app import main
 
 MODE_DIR_BINDING = "/opt/ml/model/"
 _DEFAULT_ENV_VARS = {}
@@ -25,11 +28,36 @@ class InProcessMultiModelServer:
 
     def _start_serving(self):
         """Initializes the start of the server"""
-        return Exception("Not implemented")
+
+        logger.info("Server started at http://0.0.0.0")
+
+        asyncio.create_task(main())
 
     def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
         """Invokes the MMS server by sending POST request"""
-        return Exception("Not implemented")
+
+        logger.info(request)
+        logger.info(content_type)
+        logger.info(accept)
+
+        try:
+            response = requests.post(
+                f"http://0.0.0.0:8080/generate",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+
+            return response.content
+        except requests.exceptions.ConnectionError as e:
+            logger.debug(f"Error connecting to the server: {e}")
+        except requests.exceptions.HTTPError as e:
+            logger.debug(f"HTTP error occurred: {e}")
+        except requests.exceptions.RequestException as e:
+            logger.debug(f"An error occurred: {e}")
+        except Exception as e:
+            raise Exception("Unable to send request to the local container server") from e
 
     def _multi_model_server_deep_ping(self, predictor: PredictorBase):
         """Sends a deep ping to ensure prediction"""

src/sagemaker/serve/utils/predictors.py

Lines changed: 4 additions & 2 deletions
@@ -3,7 +3,8 @@
 from __future__ import absolute_import
 import io
 from typing import Type
-
+import logging
+import json
 from sagemaker import Session
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
 from sagemaker.serve.mode.in_process_mode import InProcessMode
@@ -16,6 +17,8 @@
 
 APPLICATION_X_NPY = "application/x-npy"
 
+logger = logging.getLogger(__name__)
 
 class TorchServeLocalPredictor(PredictorBase):
     """Lightweight predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
@@ -211,7 +213,7 @@ def delete_predictor(self):
 
 
 class TransformersInProcessModePredictor(PredictorBase):
-    """Lightweight Transformers predictor for local deployment"""
+    """Lightweight Transformers predictor for in process mode deployment"""
 
     def __init__(
         self,

0 commit comments
