
Commit b463549

refactor streaming (#82)

1 parent 310de8a commit b463549

9 files changed: +588 -203 lines changed

llmserve/backend/llm/engines/_base.py (10 additions, 3 deletions)

@@ -6,11 +6,11 @@
 
 from llmserve.backend.logger import get_logger
 
-from typing import List, Optional
+from typing import List, Optional, Iterator
 from ray.air import ScalingConfig
 
 from llmserve.backend.logger import get_logger
-from llmserve.backend.server.models import Args, Prompt
+from llmserve.backend.server.models import Args, Prompt, Response
 import asyncio
 from typing import Union, AsyncGenerator, Generator
 
@@ -67,5 +67,12 @@ async def check_health(self):
         pass
 
     @abstractmethod
-    def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]:
+    async def stream(
+        self,
+        prompts: List[Prompt],
+        *,
+        timeout_s: float = 60,
+        start_timestamp: Optional[float] = None,
+        lock: asyncio.Lock,
+    ) -> Iterator[List[Response]]:
         pass
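The engine-level contract therefore becomes an async generator that yields batches of Response objects. Below is a minimal sketch of how a caller might drive it; the Prompt/Response field names used here (prompt, use_prompt_format, generated_text) are assumptions about llmserve.backend.server.models and are not shown in this diff.

import asyncio

from llmserve.backend.server.models import Prompt

async def consume_stream(engine, text: str) -> str:
    # Hypothetical driver: collect streamed Response batches into one string.
    # `engine` is any LLMEngine subclass implementing the new `stream` contract.
    lock = asyncio.Lock()
    pieces = []
    async for responses in engine.stream(
        [Prompt(prompt=text, use_prompt_format=True)],
        timeout_s=60,
        lock=lock,
    ):
        # Each yielded item is one batch: one Response per prompt in the request.
        pieces.append(responses[0].generated_text or "")
    return "".join(pieces)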

llmserve/backend/llm/engines/generic.py (85 additions, 20 deletions)

@@ -10,7 +10,7 @@
 import gc
 import os
 import traceback
-from typing import List, Optional
+from typing import List, Optional, Iterator
 
 import ray
 import ray.util
@@ -21,7 +21,7 @@
 
 from llmserve.backend.llm.initializers import get_initializer_cls_by_name
 from llmserve.backend.llm.pipelines import get_pipeline_cls_by_name
-from llmserve.backend.llm.pipelines._base import BasePipeline
+from llmserve.backend.llm.pipelines._base import BasePipeline, StreamingPipeline
 from llmserve.backend.llm.utils import (
     init_torch_dist_process_group_async,
     timeit,
@@ -171,7 +171,26 @@ def generate(
     )
     return outputs
 
-import logging
+@timeit
+def stream(
+    prompts: List[Prompt],
+    pipeline: BasePipeline,
+    **generate_kwargs,
+) -> Iterator[List[Response]]:
+    """Generate predictions using a Pipeline.
+
+    Args:
+        prompts (List[Prompt]): List of prompts.
+        pipeline (BasePipeline): Pipeline to use.
+        **generate_kwargs: Keyword arguments to pass to the pipeline's `generate` method.
+    """
+    if not isinstance(pipeline, StreamingPipeline):
+        raise RuntimeError(f"Pipeline {pipeline} does not support streaming.")
+    yield from pipeline.stream(
+        prompts,
+        **generate_kwargs,
+    )
+
 @ray.remote
 class PredictionWorker(TorchDistributedWorker):
     """A PredictionWorker is a Ray remote actor that runs a single shard of a DeepSpeed job.
@@ -277,21 +296,36 @@ def generate(
         )
         return responses_1 + responses_2
 
+    def stream(
+        self,
+        data: List[Prompt],
+        *,
+        timeout_s: Optional[float] = None,
+        start_timestamp: Optional[float] = None,
+        **kwargs,
+    ) -> Iterator[List[Response]]:
+        yield from stream(
+            data,
+            self.generator,
+            timeout_s=timeout_s,
+            start_timestamp=start_timestamp,
+            **kwargs,
+        )
+
     def __repr__(self) -> str:
        return f"{self.__class__.__name__}:{self.llm_config.model_id}"
 
     def ping(self) -> bool:
         """Ping the worker."""
         return True
 
-    async def worker_stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]], **kwargs) -> Generator[str, None, None]:  # type: ignore
-        logger.info(f"Call PredictionWorker.worker_stream_generate_texts with kwargs: {kwargs}")
-        for s in self.generator.streamGenerate(prompt, **kwargs):
-            # logger.info(f"PredictionWorker.worker_stream_generate_texts -> yield ->{s}")
-            yield s
+    def can_stream(self) -> bool:
+        """Whether the worker can stream."""
+        return isinstance(self.generator, StreamingPipeline)
 
 class GenericEngine(LLMEngine):
     base_worker_group = None
+    can_stream = None
 
     async def launch_engine(
         self,
@@ -338,11 +372,11 @@ async def launch_engine(
                     num_gpus_per_worker=scaling_config.num_gpus_per_worker
                 )
                 for worker, local_rank in zip(worker_group, local_ranks)
-                # for worker in worker_group
             ]
         )
 
         self.base_worker_group = worker_group
+        self.can_stream = await asyncio.gather(*[worker_group[0].can_stream.remote()])
        return worker_group
 
     async def predict(
@@ -429,14 +463,45 @@ async def check_health(self):
                 f"At least one prediction worker is dead. Dead workers: {dead_actors}. "
                 "Reinitializing worker group."
             )
-
-    def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]:  # type: ignore
-        logger.info(f"GenericEngine.stream_generate_texts -> worker.length: {len(self.base_worker_group)}")
-        worker0 = self.base_worker_group[0]
-        for strHandle in worker0.worker_stream_generate_texts.remote(
-            prompt,
-            **self.args.model_config.generation.all_generate_kwargs if self.args.model_config.generation else {}
-        ):
-            val = ray.get(strHandle)
-            logger.info(f"GenericEngine.stream_generate_texts -> yield -> {val}")
-            yield val
+
+    async def stream(
+        self,
+        prompts: List[Prompt],
+        *,
+        timeout_s: float = 60,
+        start_timestamp: Optional[float] = None,
+        lock: asyncio.Lock,
+    ) -> Iterator[List[Response]]:
+        """Generate text for a list of prompts.
+
+        Args:
+            prompts (List[Prompt]): Batch of prompts to generate text from.
+            timeout_s (float, optional): Timeout for the generation. Defaults
+                to 60. Ignored if start_timestamp is None.
+            start_timestamp (Optional[float], optional): Timestamp of when the
+                batch was created. Defaults to None. If set, will early stop
+                the generation.
+
+        Returns:
+            A list of generated texts.
+        """
+        if self.can_stream:
+            async with lock:
+                tasks = [
+                    worker.stream.options(num_returns="streaming").remote(
+                        prompts,
+                        timeout_s=timeout_s,
+                        start_timestamp=start_timestamp,
+                        **self.args.model_config.generation.all_generate_kwargs,
+                    )
+                    for worker in self.base_worker_group
+                ]
+                async for result in tasks[0]:
+                    yield await result
+        else:
+            logger.warning(
+                f"Pipeline {self.args.model_config.initialization.pipeline} does not support streaming. Ignoring queue."
+            )
+            yield await self.predict(
+                prompts, timeout_s=timeout_s, start_timestamp=start_timestamp
+            )
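The engine-side fan-out above relies on Ray's streaming generator returns: each worker's `stream` is a plain generator method, and calling it with `num_returns="streaming"` produces an object-ref generator that the engine iterates asynchronously, awaiting each yielded batch. A stripped-down sketch of that Ray pattern with a toy actor (not the real PredictionWorker):

import asyncio
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class ToyWorker:
    # Stand-in for PredictionWorker: an ordinary generator method on an actor.
    def stream(self, prompt: str):
        for token in prompt.split():
            yield [token]  # the real worker yields List[Response] batches

async def main():
    worker = ToyWorker.remote()
    # num_returns="streaming" turns the call into an object-ref generator that
    # can be iterated (here asynchronously) as the worker yields results.
    gen = worker.stream.options(num_returns="streaming").remote("hello streaming world")
    async for ref in gen:
        print(await ref)  # awaiting each ref resolves one yielded batch

asyncio.run(main())

Note that the diff starts a streaming call on every worker but only iterates tasks[0]; presumably all shard workers must run the forward pass while only rank 0's output is surfaced.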

llmserve/backend/llm/engines/vllm/vllm.py (12 additions, 2 deletions)

@@ -3,7 +3,7 @@
 import asyncio
 import torch
 import gc
-from typing import List, Optional, Any, Dict, List, Optional, AsyncIterator
+from typing import List, Optional, Any, Dict, List, Optional, AsyncIterator, Iterator
 from ray.air import ScalingConfig
 from ray.util.placement_group import PlacementGroup
 from llmserve.backend.server.models import Args, Prompt, Response
@@ -225,4 +225,14 @@ async def predict(
         return responses
 
     async def check_health(self):
-        logger.info("not implements yet...")
+        logger.info("not implements yet...")
+
+    async def stream(
+        self,
+        prompts: List[Prompt],
+        *,
+        timeout_s: float = 60,
+        start_timestamp: Optional[float] = None,
+        lock: asyncio.Lock,
+    ) -> Iterator[List[Response]]:
+        pass
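Here `stream` is only stubbed out for the vLLM engine. For orientation, a hypothetical sketch of how such a stub is often filled using vLLM's async API; the `engine` argument, the Prompt/Response field names, and the exact AsyncLLMEngine.generate signature (roughly vLLM 0.2.x) are all assumptions, not part of this commit:

import uuid

from vllm import SamplingParams  # vLLM's public sampling configuration

from llmserve.backend.server.models import Prompt, Response

async def stream_with_vllm(engine, prompt: Prompt):
    """Illustrative only, not part of this commit.

    `engine` is assumed to be a vllm AsyncLLMEngine; each request_output holds
    the full text generated so far, so only the new suffix is emitted per step.
    """
    params = SamplingParams(max_tokens=256)
    previous = ""
    async for request_output in engine.generate(prompt.prompt, params, uuid.uuid4().hex):
        text = request_output.outputs[0].text
        delta, previous = text[len(previous):], text
        yield [Response(generated_text=delta)]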

llmserve/backend/llm/pipelines/_base.py (1 addition, 6 deletions)

@@ -340,15 +340,10 @@ def _sanitize_parameters(
 
         return preprocess_params, forward_params, postprocess_params
 
-    @abstractmethod
-    def streamGenerate(self, prompt: Union[Prompt, List[Prompt]], **generate_kwargs) -> Generator[str, None, None]:
-        pass
-
 class StreamingPipeline(BasePipeline):
     def stream(
         self,
-        inputs: List[str],
-        queue: Queue,
+        inputs: List[Union[str, Prompt]],
         **kwargs,
     ) -> Iterator[List[Response]]:
         raise NotImplementedError()
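With `streamGenerate` removed from BasePipeline, streaming support is now expressed solely by subclassing StreamingPipeline and overriding `stream` with the widened `List[Union[str, Prompt]]` input type. A toy override illustrating the expected shape; only `stream` is shown, the rest of BasePipeline's abstract surface is omitted, and the Prompt/Response field names are assumptions:

from typing import Iterator, List, Union

from llmserve.backend.llm.pipelines._base import StreamingPipeline
from llmserve.backend.server.models import Prompt, Response

class EchoStreamingPipeline(StreamingPipeline):
    """Toy example: streams the input back one whitespace token at a time."""

    def stream(
        self,
        inputs: List[Union[str, Prompt]],
        **kwargs,
    ) -> Iterator[List[Response]]:
        texts = [i.prompt if isinstance(i, Prompt) else i for i in inputs]
        tokens = [t.split() for t in texts]
        for step in range(max(len(t) for t in tokens)):
            # One Response per input prompt on every step, mirroring the batch shape.
            yield [
                Response(generated_text=t[step] if step < len(t) else "")
                for t in tokens
            ]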
