4
4
5
5
import logging
6
6
from transformers import pipeline
7
- from fastapi import FastAPI
7
+ from fastapi import FastAPI , Request
8
8
import uvicorn
9
9
10
10
logger = logging .getLogger (__name__ )
@@ -22,31 +22,33 @@ def read_root():
22
22
return {"Hello" : "World" }
23
23
24
24
25
- @app .post ("/generate" )
26
- def generate_text (prompt : str , max_length = 500 , num_return_sequences = 1 ):
25
+ @app .get ("/generate" )
26
+ async def generate_text (prompt : Request ):
27
27
"""Placeholder docstring"""
28
28
logger .info ("Generating Text...." )
29
29
30
- generated_text = generator (
31
- prompt , max_length = max_length , num_return_sequences = num_return_sequences
32
- )
30
+ str_prompt = await prompt .json ()
31
+
32
+ logger .info (str_prompt )
33
+
34
+ generated_text = generator (str_prompt , max_length = 30 , num_return_sequences = 5 , truncation = True )
33
35
return generated_text [0 ]["generated_text" ]
34
36
35
37
36
38
generator = pipeline ("text-generation" , model = "gpt2" )
37
39
38
40
39
41
@app .post ("/post" )
40
- def post (prompt : str ):
42
+ def post (payload : dict ):
41
43
"""Placeholder docstring"""
42
- return prompt
44
+ return payload
43
45
44
46
45
47
async def main ():
46
48
"""Running server locally with uvicorn"""
47
49
logger .info ("Running" )
48
50
config = uvicorn .Config (
49
- "sagemaker.app:app" , host = "0.0.0.0" , port = 8080 , log_level = "info" , loop = "asyncio"
51
+ "sagemaker.app:app" , host = "0.0.0.0" , port = 8000 , log_level = "info" , loop = "asyncio"
50
52
)
51
53
server = uvicorn .Server (config )
52
54
await server .serve ()
0 commit comments