 from llmserve.api import sdk
 from llmserve.common.utils import _replace_prefix, _reverse_prefix

+from starlette.responses import StreamingResponse
+from typing import AsyncGenerator, Generator
+from ray.serve.handle import DeploymentHandle, DeploymentResponseGenerator
+
 # logger = get_logger(__name__)
 logger = get_logger("ray.serve")

@@ -303,7 +307,6 @@ async def generate_text_batch(
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}:{self.args.model_config.model_id}"

-
 @serve.deployment(
     # TODO make this configurable in llmserve run
     autoscaling_config={
@@ -315,12 +318,16 @@ def __repr__(self) -> str:
 )
 @serve.ingress(app)
 class RouterDeployment:
-    def __init__(
-        self, models: Dict[str, ClassNode], model_configurations: Dict[str, Args]
-    ) -> None:
+    def __init__(self, models: Dict[str, DeploymentHandle], model_configurations: Dict[str, Args]) -> None:
         self._models = models
         # TODO: Remove this once it is possible to reconfigure models on the fly
         self._model_configurations = model_configurations
+        logger.info(f"init: _models.keys: {self._models.keys()}")
+        # logger.info(f"init model_configurations: {model_configurations}")
+        for modelkey in self._models.keys():
+            if self._model_configurations[modelkey].model_config.stream:
+                logger.info(f"Set stream=true for {modelkey}")
+                self._models[modelkey] = self._models[modelkey].options(stream=True)

     @app.post("/{model}/run/predict")
     async def predict(self, model: str, prompt: Union[Prompt, List[Prompt]]) -> Union[Dict[str, Any], List[Dict[str, Any]], List[Any]]:
@@ -364,6 +371,30 @@ async def metadata(self, model: str) -> Dict[str, Dict[str, Any]]:
     async def models(self) -> List[str]:
         return list(self._models.keys())

+    @app.post("/run/stream")
+    def streamer(self, data: dict) -> StreamingResponse:
+        logger.info(f"data: {data}")
+        logger.info(f'Got stream -> body: {data}, keys: {self._models.keys()}')
+        prompt = data.get("prompt")
+        model = data.get("model")
+        modelKeys = list(self._models.keys())
+        modelID = model
+        for item in modelKeys:
+            logger.info(f"_reverse_prefix(item): {_reverse_prefix(item)}")
+            if _reverse_prefix(item) == model:
+                modelID = item
+                logger.info(f"set stream model id: {item}")
+        logger.info(f"search stream model key: {modelID}")
+        return StreamingResponse(self.streamer_generate_text(modelID, prompt), media_type="text/plain")
+
+    async def streamer_generate_text(self, modelID: str, prompt: str) -> AsyncGenerator[str, None]:
+        logger.info(f'streamer_generate_text: {modelID}, prompt: "{prompt}"')
+        r: DeploymentResponseGenerator = self._models[modelID].stream_generate_texts.remote(prompt)
+        async for i in r:
+            # logger.info(f"RouterDeployment.streamer_generate_text -> yield -> {type(i)}->{i}")
+            if not isinstance(i, str):
+                continue
+            yield i


 @serve.deployment(
     # TODO make this configurable in llmserve run
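
For reference, a minimal client sketch for exercising the new /run/stream endpoint added in this diff. The route and the "prompt"/"model" JSON keys come from RouterDeployment.streamer above; the host, port, and model name below are placeholders for an assumed local llmserve deployment.

    # Hypothetical client for the text/plain StreamingResponse returned by /run/stream.
    import requests

    payload = {"model": "my-model", "prompt": "Write a haiku about GPUs."}  # model name is an assumption
    with requests.post("http://localhost:8000/run/stream", json=payload, stream=True) as resp:
        resp.raise_for_status()
        # Print plain-text chunks as the router yields them from streamer_generate_text.
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                print(chunk, end="", flush=True)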