Commit f3d2847

add streaming support (#66)
1 parent 6013d57 commit f3d2847

8 files changed: +115, -8 lines changed

llmserve/backend/llm/engines/_base.py

Lines changed: 5 additions & 0 deletions

@@ -12,6 +12,7 @@
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Args, Prompt
 import asyncio
+from typing import AsyncGenerator, Generator

 logger = get_logger(__name__)

@@ -63,4 +64,8 @@ async def predict(

     @abstractmethod
     async def check_health(self):
+        pass
+
+    @abstractmethod
+    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:
         pass
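The two abstract methods above are the contract every engine now has to satisfy: an async health check plus a synchronous generator that yields text chunks. A minimal, self-contained sketch of that contract (MiniEngine and EchoEngine are made-up names for illustration, not repo code, and the real LLMEngine has further abstract members not shown in this hunk):

from abc import ABC, abstractmethod
from typing import Generator

class MiniEngine(ABC):
    """Stand-in for LLMEngine, reduced to the two methods touched by this diff."""

    @abstractmethod
    async def check_health(self):
        pass

    @abstractmethod
    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:
        pass

class EchoEngine(MiniEngine):
    async def check_health(self):
        return None

    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:
        # A real engine would stream model output; this just echoes words back.
        for word in prompt.split():
            yield word + " "

print("".join(EchoEngine().stream_generate_texts("hello streaming world")))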

llmserve/backend/llm/engines/generic.py

Lines changed: 24 additions & 1 deletion

@@ -31,6 +31,12 @@
 from llmserve.backend.server.models import Args, LLMConfig, Prompt, Response
 from llmserve.backend.server.utils import render_gradio_params
 from ._base import LLMEngine
+
+from typing import AsyncGenerator, Generator
+from queue import Empty
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
+
 logger = get_logger(__name__)

 @timeit
@@ -278,6 +284,12 @@ def ping(self) -> bool:
         """Ping the worker."""
         return True

+    async def worker_stream_generate_texts(self, prompt: str, **kwargs) -> Generator[str, None, None]:  # type: ignore
+        logger.info(f"Call PredictionWorker.worker_stream_generate_texts with kwargs: {kwargs}")
+        for s in self.generator.streamGenerate(prompt, **kwargs):
+            logger.info(f"PredictionWorker.worker_stream_generate_texts -> yield ->{s}")
+            yield s
+
 class GenericEngine(LLMEngine):
     base_worker_group = None

@@ -416,4 +428,15 @@ async def check_health(self):
             raise RuntimeError(
                 f"At least one prediction worker is dead. Dead workers: {dead_actors}. "
                 "Reinitializing worker group."
-            )
+            )
+
+    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:  # type: ignore
+        logger.info(f"GenericEngine.stream_generate_texts -> worker.length: {len(self.base_worker_group)}")
+        worker0 = self.base_worker_group[0]
+        for strHandle in worker0.worker_stream_generate_texts.remote(
+            prompt,
+            **self.args.model_config.generation.all_generate_kwargs if self.args.model_config.generation else {}
+        ):
+            val = ray.get(strHandle)
+            logger.info(f"GenericEngine.stream_generate_texts -> yield -> {val}")
+            yield val
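GenericEngine streams through Ray: the worker method is a Python generator, so iterating the result of the .remote() call produces object refs that are resolved one at a time with ray.get() before being re-yielded. A minimal sketch of that same pattern outside the repo (EchoWorker is a made-up actor, and exact generator-call semantics vary by Ray version):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class EchoWorker:
    # Generator actor method: each yielded value reaches the caller as an object ref.
    def stream_tokens(self, prompt: str):
        for word in prompt.split():
            yield word + " "

worker = EchoWorker.remote()
# Same consumption pattern as GenericEngine.stream_generate_texts above.
for ref in worker.stream_tokens.remote("streaming tokens from a Ray actor"):
    print(ray.get(ref), end="", flush=True)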

llmserve/backend/llm/pipelines/_base.py

Lines changed: 5 additions & 0 deletions

@@ -22,6 +22,8 @@
 from .processors import StopOnTokens
 from .utils import tokenize_stopping_sequences_where_needed

+from typing import AsyncGenerator, Generator
+
 if TYPE_CHECKING:
     from ..initializers._base import LLMInitializer

@@ -338,6 +340,9 @@ def _sanitize_parameters(

         return preprocess_params, forward_params, postprocess_params

+    @abstractmethod
+    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+        pass

 class StreamingPipeline(BasePipeline):
     def stream(

llmserve/backend/llm/pipelines/default_pipeline.py

Lines changed: 34 additions & 0 deletions

@@ -12,6 +12,12 @@
 from .processors import StopOnTokens
 from .utils import construct_prompts, truncate_to_first_stop_token

+from typing import AsyncGenerator, Generator
+import asyncio
+from transformers import TextIteratorStreamer
+from threading import Thread
+from queue import Empty
+
 logger = get_logger(__name__)


@@ -160,3 +166,31 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
         response.generation_time = model_outputs["generation_time"]
         response.postprocessing_time = et
         return decoded
+
+    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+        logger.info(f"DefaultPipeline.streamGenerate with generate_kwargs: {generate_kwargs}")
+        streamer = TextIteratorStreamer(self.tokenizer, timeout=0, skip_prompt=True, skip_special_tokens=True)
+        input_ids = self.tokenizer([prompt], return_tensors="pt")
+        # generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=20)
+        max_new_tokens = 256
+        if generate_kwargs["max_new_tokens"]:
+            max_new_tokens = generate_kwargs["max_new_tokens"]
+        generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=max_new_tokens)
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+        while True:
+            try:
+                for token in streamer:
+                    logger.info(f'DefaultPipeline.streamGenerate -> Yield -> "{token}" -> "{type(token)}"')
+                    yield token
+                break
+            except Empty:
+                asyncio.sleep(0.001)
+
+        # start = 0
+        # while True:
+        #     val = prompt + str(start)
+        #     logger.info(f"PredictionWorker.worker_stream_generate_texts -> yield -> {val}")
+        #     yield val
+        #     start += 1
+        #     asyncio.sleep(1)
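streamGenerate runs model.generate on a background thread and drains decoded tokens through Hugging Face's TextIteratorStreamer; with timeout=0 the streamer raises queue.Empty whenever no token is ready yet, which the while/except loop absorbs. (In the synchronous generator above, the un-awaited asyncio.sleep(0.001) returns a coroutine without actually pausing; a blocking time.sleep would be the direct equivalent in plain generator code.) A standalone sketch of the same pattern with a blocking streamer (no timeout), using the facebook/opt-125m model from the sample config in this commit; the prompt and token budget are illustrative:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "facebook/opt-125m"          # same model as the sample yaml below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# No timeout: iterating the streamer simply blocks until the next decoded chunk arrives.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer(["Ray Serve makes it easy to"], return_tensors="pt")

thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for text in streamer:                   # yields text pieces as generate() produces them
    print(text, end="", flush=True)
thread.join()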

llmserve/backend/llm/predictor.py

Lines changed: 9 additions & 2 deletions

@@ -14,6 +14,8 @@
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Args, Prompt

+from typing import AsyncGenerator, Generator
+
 initialize_node_remote = ray.remote(initialize_node)
 logger = get_logger(__name__)

@@ -59,7 +61,6 @@ async def rollover(self, scaling_config: ScalingConfig, pg_timeout_s: float = 60
             args = self.args
         )

-
         self.new_worker_group = await self._create_worker_group(
             scaling_config, pg_timeout_s=pg_timeout_s
         )
@@ -178,4 +179,10 @@ async def _predict_async(

     # Called by Serve to check the replica's health.
     async def check_health(self):
-        self.engine.check_health()
+        self.engine.check_health()
+
+    async def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:  # type: ignore
+        logger.info(f"call LLMPredictor.stream_generate_texts")
+        for s in self.engine.stream_generate_texts(prompt):
+            logger.info(f"LLMPredictor.stream_generate_texts -> yield ->{s}")
+            yield s
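LLMPredictor.stream_generate_texts is an async generator that loops over the engine's synchronous generator, so the event loop is blocked while it waits for each chunk. A sketch (not from the repo) of one way to bridge a blocking generator into async code by offloading each next() call to a worker thread; aiter_blocking and slow_tokens are hypothetical names:

import asyncio

async def aiter_blocking(sync_gen):
    """Yield from a blocking generator without stalling the event loop:
    each next() call runs in a worker thread (Python 3.9+)."""
    sentinel = object()
    iterator = iter(sync_gen)
    while True:
        item = await asyncio.to_thread(next, iterator, sentinel)
        if item is sentinel:
            return
        yield item

async def demo():
    def slow_tokens():
        yield from ["hello ", "streaming ", "world"]

    async for tok in aiter_blocking(slow_tokens()):
        print(tok, end="", flush=True)

asyncio.run(demo())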

llmserve/backend/server/app.py

Lines changed: 35 additions & 4 deletions

@@ -46,6 +46,10 @@
 from llmserve.api import sdk
 from llmserve.common.utils import _replace_prefix, _reverse_prefix

+from starlette.responses import StreamingResponse
+from typing import AsyncGenerator, Generator
+from ray.serve.handle import DeploymentHandle, DeploymentResponseGenerator
+
 # logger = get_logger(__name__)
 logger = get_logger("ray.serve")

@@ -303,7 +307,6 @@ async def generate_text_batch(
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}:{self.args.model_config.model_id}"

-
 @serve.deployment(
     # TODO make this configurable in llmserve run
     autoscaling_config={
@@ -315,12 +318,16 @@ def __repr__(self) -> str:
 )
 @serve.ingress(app)
 class RouterDeployment:
-    def __init__(
-        self, models: Dict[str, ClassNode], model_configurations: Dict[str, Args]
-    ) -> None:
+    def __init__(self, models: Dict[str, DeploymentHandle], model_configurations: Dict[str, Args]) -> None:
         self._models = models
         # TODO: Remove this once it is possible to reconfigure models on the fly
         self._model_configurations = model_configurations
+        logger.info(f"init: _models.keys: {self._models.keys()}")
+        # logger.info(f"init model_configurations: {model_configurations}")
+        for modelkey in self._models.keys():
+            if self._model_configurations[modelkey].model_config.stream:
+                logger.info(f"Set stream=true for {modelkey}")
+                self._models[modelkey] = self._models[modelkey].options(stream=True)

     @app.post("/{model}/run/predict")
     async def predict(self, model: str, prompt: Union[Prompt, List[Prompt]]) -> Union[Dict[str, Any], List[Dict[str, Any]], List[Any]]:
@@ -364,6 +371,30 @@ async def metadata(self, model: str) -> Dict[str, Dict[str, Any]]:
     async def models(self) -> List[str]:
         return list(self._models.keys())

+    @app.post("/run/stream")
+    def streamer(self, data: dict) -> StreamingResponse:
+        logger.info(f"data: {data}")
+        logger.info(f'Got stream -> body: {data}, keys: {self._models.keys()}')
+        prompt = data.get("prompt")
+        model = data.get("model")
+        modelKeys = list(self._models.keys())
+        modelID = model
+        for item in modelKeys:
+            logger.info(f"_reverse_prefix(item): {_reverse_prefix(item)}")
+            if _reverse_prefix(item) == model:
+                modelID = item
+                logger.info(f"set stream model id: {item}")
+        logger.info(f"search stream model key: {modelID}")
+        return StreamingResponse(self.streamer_generate_text(modelID, prompt), media_type="text/plain")
+
+    async def streamer_generate_text(self, modelID: str, prompt: str) -> AsyncGenerator[str, None]:
+        logger.info(f'streamer_generate_text: {modelID}, prompt: "{prompt}"')
+        r: DeploymentResponseGenerator = self._models[modelID].stream_generate_texts.remote(prompt)
+        async for i in r:
+            # logger.info(f"RouterDeployment.streamer_generate_text -> yield -> {type(i)}->{i}")
+            if not isinstance(i, str):
+                continue
+            yield i

 @serve.deployment(
     # TODO make this configurable in llmserve run
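On the wire, the new route accepts a JSON body with "model" and "prompt" keys and returns a chunked text/plain response. A hypothetical client sketch (the host, port, and route prefix depend on how the router deployment is exposed, and the model name shown is only a guess at the URL-safe form that _reverse_prefix() matches):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/run/stream",      # adjust to your Serve route prefix
    json={"model": "facebook--opt-125m", "prompt": "What does Ray Serve do?"},
    stream=True,                             # keep the connection open and read chunks as they arrive
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)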

llmserve/backend/server/models.py

Lines changed: 1 addition & 0 deletions

@@ -377,6 +377,7 @@ def all_generate_kwargs(self) -> Dict[str, Any]:


 class LLMConfig(BaseModelExtended):
+    stream: bool = False  # enable streaming api
     warmup: bool  # need warmup?
     model_task: str  # need verification, TODO
     model_id: str

models/text-generation--facebook--opt-125m.yaml

Lines changed: 2 additions & 1 deletion

@@ -12,6 +12,7 @@ deployment_config:
   ray_actor_options:
     num_cpus: 0.1 # for a model deployment, we have 3 actor created, 1 and 2 will cost 0.1 cpu, and the model infrence will cost 6(see the setting in the end of the file)
 model_config:
+  stream: False
   warmup: True
   model_task: text-generation
   model_id: facebook/opt-125m
@@ -20,7 +21,7 @@ model_config:
   # s3_mirror_config:
   #   endpoint_url: http://39.107.108.170:9000 # Optinal for custom S3 storage endpoint url
   #   bucket_uri: s3://opt-125m/facemodel/ # Must include hash file with commit id in repo
-  #   bucket_uri: /tmp/hub/opt-125m/ # Local path of model with hash file
+  #   bucket_uri: /Users/hhwang/models/opt-125m/ # Local path of model with hash file
   #   git_uri: https://portal.opencsg.com/models/opt-125m.git # git address for git clone
   initializer:
     type: SingleDevice
