1616import pyrootutils
1717import soundfile as sf
1818import torch
19- from kui .wsgi import (
19+ from kui .asgi import (
2020 Body ,
21+ FileResponse ,
2122 HTTPException ,
2223 HttpView ,
2324 JSONResponse ,
2425 Kui ,
2526 OpenAPI ,
2627 StreamResponse ,
2728)
28- from kui .wsgi .routing import MultimethodRoutes
29+ from kui .asgi .routing import MultimethodRoutes
2930from loguru import logger
3031from pydantic import BaseModel , Field
3132from transformers import AutoTokenizer
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
5758
5859
5960# Define utils for web server
60- def http_execption_handler (exc : HTTPException ):
61+ async def http_execption_handler (exc : HTTPException ):
6162 return JSONResponse (
6263 dict (
6364 statusCode = exc .status_code ,
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
6970 )
7071
7172
72- def other_exception_handler (exc : "Exception" ):
73+ async def other_exception_handler (exc : "Exception" ):
7374 traceback .print_exc ()
7475
7576 status = HTTPStatus .INTERNAL_SERVER_ERROR
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
334335 yield fake_audios
335336
336337
async def inference_async(req: InvokeRequest):
    """Bridge the synchronous `inference` generator into an async generator.

    `inference(req)` is a blocking (CPU/GPU-bound) generator. Iterating it
    directly inside an `async def` would stall the event loop for the entire
    duration of each chunk's synthesis, starving concurrent requests and the
    /v1/health endpoint. Each `next()` call is therefore dispatched to a
    worker thread via `asyncio.to_thread`, mirroring the threaded execution
    model of the previous WSGI deployment.

    Args:
        req: The invocation request, forwarded verbatim to `inference`.

    Yields:
        Whatever `inference(req)` yields (audio chunks / encoded bytes),
        in the same order.
    """
    import asyncio  # local import keeps this change self-contained

    chunks = iter(inference(req))
    # Sentinel object lets `next` report exhaustion without propagating
    # StopIteration across the thread boundary.
    done = object()
    while (chunk := await asyncio.to_thread(next, chunks, done)) is not done:
        yield chunk
341+
342+
async def buffer_to_async_generator(buffer):
    """Expose an already-materialized payload as a one-shot async iterable.

    `StreamResponse` consumes an async iterable; this wraps a fully encoded
    in-memory payload (e.g. the bytes of a finished WAV file) so it can be
    sent as a single chunk.
    """
    for item in (buffer,):
        yield item
345+
346+
337347@routes .http .post ("/v1/invoke" )
338- def api_invoke_model (
348+ async def api_invoke_model (
339349 req : Annotated [InvokeRequest , Body (exclusive = True )],
340350):
341351 """
@@ -354,22 +364,21 @@ def api_invoke_model(
354364 content = "Streaming only supports WAV format" ,
355365 )
356366
357- generator = inference (req )
358367 if req .streaming :
359368 return StreamResponse (
360- iterable = generator ,
369+ iterable = inference_async ( req ) ,
361370 headers = {
362371 "Content-Disposition" : f"attachment; filename=audio.{ req .format } " ,
363372 },
364373 content_type = get_content_type (req .format ),
365374 )
366375 else :
367- fake_audios = next (generator )
376+ fake_audios = next (inference ( req ) )
368377 buffer = io .BytesIO ()
369378 sf .write (buffer , fake_audios , decoder_model .sampling_rate , format = req .format )
370379
371380 return StreamResponse (
372- iterable = [ buffer .getvalue ()] ,
381+ iterable = buffer_to_async_generator ( buffer .getvalue ()) ,
373382 headers = {
374383 "Content-Disposition" : f"attachment; filename=audio.{ req .format } " ,
375384 },
@@ -378,7 +387,7 @@ def api_invoke_model(
378387
379388
380389@routes .http .post ("/v1/health" )
381- def api_health ():
390+ async def api_health ():
382391 """
383392 Health check
384393 """
@@ -409,6 +418,7 @@ def parse_args():
409418 parser .add_argument ("--compile" , action = "store_true" )
410419 parser .add_argument ("--max-text-length" , type = int , default = 0 )
411420 parser .add_argument ("--listen" , type = str , default = "127.0.0.1:8000" )
421+ parser .add_argument ("--workers" , type = int , default = 1 )
412422
413423 return parser .parse_args ()
414424
@@ -433,7 +443,7 @@ def parse_args():
433443if __name__ == "__main__" :
434444 import threading
435445
436- from zibai import create_bind_socket , serve
446+ import uvicorn
437447
438448 args = parse_args ()
439449 args .precision = torch .half if args .half else torch .bfloat16
@@ -480,13 +490,5 @@ def parse_args():
480490 )
481491
482492 logger .info (f"Warming up done, starting server at http://{ args .listen } " )
483- sock = create_bind_socket (args .listen )
484- sock .listen ()
485-
486- # Start server
487- serve (
488- app = app ,
489- bind_sockets = [sock ],
490- max_workers = 10 ,
491- graceful_exit = threading .Event (),
492- )
493+ host , port = args .listen .split (":" )
494+ uvicorn .run (app , host = host , port = int (port ), workers = args .workers , log_level = "info" )
0 commit comments