Skip to content

Commit 4e6bd26

Browse files
author
Bryannah Hernandez
committed
fastapi predictor fix
1 parent 7939237 commit 4e6bd26

File tree

5 files changed

+58
-63
lines changed

5 files changed

+58
-63
lines changed

src/sagemaker/serve/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def main():
6262
config = uvicorn.Config(
6363
"sagemaker.app:app",
6464
host="127.0.0.1",
65-
port=8080,
65+
port=9007,
6666
log_level="info",
6767
loop="asyncio",
6868
reload=True,

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,10 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
237237
)
238238

239239
self.modes[str(Mode.IN_PROCESS)].create_server(
240-
self.image_uri,
241-
timeout if timeout else DEFAULT_TIMEOUT,
242-
None,
243240
predictor,
244-
self.pysdk_model.env,
245241
)
246242
return predictor
247-
243+
248244
self._set_instance(kwargs)
249245

250246
if "mode" in kwargs:

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,7 @@ def prepare(self):
6363

6464
def create_server(
6565
self,
66-
image: str,
67-
secret_key: str,
6866
predictor: PredictorBase,
69-
env_vars: Dict[str, str] = None,
70-
model_path: str = None,
7167
):
7268
"""Creating the server and checking ping health."""
7369
logger.info("Waiting for model server %s to start up...", self.model_server)

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
import platform
99
import time
10-
import json
1110
from pathlib import Path
1211
from sagemaker import Session, fw_utils
1312
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -16,6 +15,7 @@
1615
from sagemaker.s3 import S3Uploader
1716
from sagemaker.local.utils import get_docker_host
1817
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
18+
from sagemaker.serve.app import main
1919

2020
MODE_DIR_BINDING = "/opt/ml/model/"
2121
_DEFAULT_ENV_VARS = {}
@@ -28,55 +28,64 @@ class InProcessMultiModelServer:
2828

2929
def _start_serving(self):
3030
"""Initializes the start of the server"""
31-
from sagemaker.serve.app import main
32-
33-
asyncio.create_task(main())
31+
background_tasks = set()
32+
task = asyncio.create_task(main())
33+
background_tasks.add(task)
34+
task.add_done_callback(background_tasks.discard)
3435

3536
time.sleep(10)
3637

3738
def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
3839
"""Placeholder docstring"""
39-
try: # for Python 3
40-
from http.client import HTTPConnection
41-
except ImportError:
42-
from httplib import HTTPConnection
43-
44-
HTTPConnection.debuglevel = 1
45-
logging.basicConfig() # you need to initialize logging, otherwise you will not see
46-
# anything from requests
47-
logging.getLogger().setLevel(logging.DEBUG)
48-
requests_log = logging.getLogger("urllib3")
49-
requests_log.setLevel(logging.DEBUG)
50-
requests_log.propagate = True
51-
52-
try:
53-
requests.get("http://127.0.0.1:8080/", verify=False).json()
54-
except Exception as ex:
55-
logger.error(ex)
56-
raise ex
57-
58-
try:
59-
response = requests.get(
60-
"http://127.0.0.1:8080/generate",
61-
json=json.dumps(request),
62-
headers={"Content-Type": content_type, "Accept": accept},
63-
timeout=600,
64-
).json()
65-
66-
return response
67-
except requests.exceptions.ConnectionError as e:
68-
logger.debug(f"Error connecting to the server: {e}")
69-
except requests.exceptions.HTTPError as e:
70-
logger.debug(f"HTTP error occurred: {e}")
71-
except requests.exceptions.RequestException as e:
72-
logger.debug(f"An error occurred: {e}")
73-
except Exception as e:
74-
raise Exception("Unable to send request to the local container server") from e
40+
background_tasks = set()
41+
task = asyncio.create_task(self.generate_connect())
42+
background_tasks.add(task)
43+
task.add_done_callback(background_tasks.discard)
7544

7645
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
7746
"""Sends a deep ping to ensure prediction"""
47+
background_tasks = set()
48+
task = asyncio.create_task(self.tcp_connect())
49+
background_tasks.add(task)
50+
task.add_done_callback(background_tasks.discard)
7851
response = None
79-
return (True, response)
52+
return True, response
53+
54+
async def generate_connect(self, host="127.0.0.1", port=9007):
    """Send a sample ``GET /generate`` request to the local server and log the reply.

    Opens a TCP connection, writes a hand-built HTTP/1.1 request whose JSON
    body is a fixed sample prompt, reads the response until the server closes
    the connection, and logs the raw response bytes.

    Args:
        host: Address the server listens on. Defaults to ``"127.0.0.1"``.
        port: TCP port the server listens on. Defaults to ``9007``.

    Raises:
        OSError: If the connection cannot be established.
    """
    payload = b'"\\"Hello, I\'m a language model\\""'
    authority = f"{host}:{port}".encode("ascii")

    # The request head must be ONE bytes object. The previous version had stray
    # commas that turned it into a (bytes, str) tuple, which the transport
    # rejects with TypeError.
    request_head = b"".join(
        (
            b"GET /generate HTTP/1.1\r\n",
            b"Host: " + authority + b"\r\n",
            b"User-Agent: python-requests/2.31.0\r\n",
            b"Accept-Encoding: gzip, deflate, br\r\n",
            b"Accept: */*\r\n",
            # "close" rather than "keep-alive": reader.read() below only
            # returns at EOF, so the server has to close after responding.
            b"Connection: close\r\n",
            # Derived from the payload so the two can never drift apart.
            b"Content-Length: " + str(len(payload)).encode("ascii") + b"\r\n",
            b"Content-Type: application/json\r\n",
            b"\r\n",
        )
    )

    reader, writer = await asyncio.open_connection(host, port)
    try:
        writer.writelines([request_head, payload])
        await writer.drain()
        logger.debug(writer.get_extra_info("peername"))
        logger.debug(writer.transport)

        data = await reader.read()
        logger.info("Response from server")
        logger.info(data)
    finally:
        # Release the socket even if writing or reading failed.
        writer.close()
        await writer.wait_closed()
73+
74+
async def tcp_connect(self, host="127.0.0.1", port=9007):
    """Send a bare ``GET /`` request to the local server and log the reply.

    Opens a TCP connection, writes a hand-built HTTP/1.1 request with no body,
    reads the response until the server closes the connection, and logs the
    raw response bytes.

    Args:
        host: Address the server listens on. Defaults to ``"127.0.0.1"``.
        port: TCP port the server listens on. Defaults to ``9007``.

    Raises:
        OSError: If the connection cannot be established.
    """
    authority = f"{host}:{port}".encode("ascii")

    # StreamWriter.write() takes exactly one bytes-like argument. The previous
    # version passed a bytes literal AND a str literal as two arguments, which
    # raises TypeError before anything is sent.
    request = b"".join(
        (
            b"GET / HTTP/1.1\r\n",
            b"Host: " + authority + b"\r\n",
            b"User-Agent: python-requests/2.32.3\r\n",
            b"Accept-Encoding: gzip, deflate, br\r\n",
            b"Accept: */*\r\n",
            # "close" rather than "keep-alive": reader.read() below only
            # returns at EOF, so the server has to close after responding.
            b"Connection: close\r\n",
            b"\r\n",
        )
    )

    reader, writer = await asyncio.open_connection(host, port)
    try:
        writer.write(request)
        await writer.drain()
        logger.debug(writer.get_extra_info("peername"))
        logger.debug(writer.transport)

        data = await reader.read()
        logger.info("Response from server")
        logger.info(data)
    finally:
        # Release the socket even if writing or reading failed.
        writer.close()
        await writer.wait_closed()
8089

8190

8291
class LocalMultiModelServer:

src/sagemaker/serve/utils/predictors.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -229,18 +229,12 @@ def predict(self, data):
229229
"""Placeholder docstring"""
230230
logger.info("Entering predict to make a prediction on ")
231231
logger.info(data)
232-
return [
233-
self.deserializer.deserialize(
234-
io.BytesIO(
235-
self._mode_obj._invoke_multi_model_server_serving(
236-
self.serializer.serialize(data),
237-
self.content_type,
238-
self.deserializer.ACCEPT[0],
239-
)
240-
),
241-
self.content_type,
242-
)
243-
]
232+
233+
return self._mode_obj._invoke_multi_model_server_serving(
234+
self.serializer.serialize(data),
235+
self.content_type,
236+
self.deserializer.ACCEPT[0],
237+
)
244238

245239
@property
246240
def content_type(self):

0 commit comments

Comments
 (0)