Skip to content

Commit 6f34c61

Browse files
author
Bryannah Hernandez
committed
changes
1 parent 7e17631 commit 6f34c61

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/sagemaker/app.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
from transformers import pipeline
7-
from fastapi import FastAPI
7+
from fastapi import FastAPI, Request
88
import uvicorn
99

1010
logger = logging.getLogger(__name__)
@@ -22,31 +22,33 @@ def read_root():
2222
return {"Hello": "World"}
2323

2424

25-
@app.post("/generate")
26-
def generate_text(prompt: str, max_length=500, num_return_sequences=1):
25+
@app.get("/generate")
26+
async def generate_text(prompt: Request):
2727
"""Placeholder docstring"""
2828
logger.info("Generating Text....")
2929

30-
generated_text = generator(
31-
prompt, max_length=max_length, num_return_sequences=num_return_sequences
32-
)
30+
str_prompt = await prompt.json()
31+
32+
logger.info(str_prompt)
33+
34+
generated_text = generator(str_prompt, max_length=30, num_return_sequences=5, truncation=True)
3335
return generated_text[0]["generated_text"]
3436

3537

3638
generator = pipeline("text-generation", model="gpt2")
3739

3840

3941
@app.post("/post")
40-
def post(prompt: str):
42+
def post(payload: dict):
4143
"""Placeholder docstring"""
42-
return prompt
44+
return payload
4345

4446

4547
async def main():
4648
"""Running server locally with uvicorn"""
4749
logger.info("Running")
4850
config = uvicorn.Config(
49-
"sagemaker.app:app", host="0.0.0.0", port=8080, log_level="info", loop="asyncio"
51+
"sagemaker.app:app", host="0.0.0.0", port=8000, log_level="info", loop="asyncio"
5052
)
5153
server = uvicorn.Server(config)
5254
await server.serve()

0 commit comments

Comments
 (0)