import gc
import os
import traceback
- from typing import List, Optional
+ from typing import List, Optional, Iterator

import ray
import ray.util
@@ -21,7 +21,7 @@

from llmserve.backend.llm.initializers import get_initializer_cls_by_name
from llmserve.backend.llm.pipelines import get_pipeline_cls_by_name
- from llmserve.backend.llm.pipelines._base import BasePipeline
+ from llmserve.backend.llm.pipelines._base import BasePipeline, StreamingPipeline
from llmserve.backend.llm.utils import (
    init_torch_dist_process_group_async,
    timeit,
@@ -171,7 +171,26 @@ def generate(
    )
    return outputs

- import logging
+ @timeit
+ def stream(
+     prompts: List[Prompt],
+     pipeline: BasePipeline,
+     **generate_kwargs,
+ ) -> Iterator[List[Response]]:
+     """Generate predictions using a Pipeline.
+
+     Args:
+         prompts (List[Prompt]): List of prompts.
+         pipeline (BasePipeline): Pipeline to use.
+         **generate_kwargs: Keyword arguments to pass to the pipeline's `generate` method.
+     """
+     if not isinstance(pipeline, StreamingPipeline):
+         raise RuntimeError(f"Pipeline {pipeline} does not support streaming.")
+     yield from pipeline.stream(
+         prompts,
+         **generate_kwargs,
+     )
+
@ray.remote
class PredictionWorker(TorchDistributedWorker):
    """A PredictionWorker is a Ray remote actor that runs a single shard of a DeepSpeed job.
@@ -277,21 +296,36 @@ def generate(
        )
        return responses_1 + responses_2

+     def stream(
+         self,
+         data: List[Prompt],
+         *,
+         timeout_s: Optional[float] = None,
+         start_timestamp: Optional[float] = None,
+         **kwargs,
+     ) -> Iterator[List[Response]]:
+         yield from stream(
+             data,
+             self.generator,
+             timeout_s=timeout_s,
+             start_timestamp=start_timestamp,
+             **kwargs,
+         )
+
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}:{self.llm_config.model_id}"

    def ping(self) -> bool:
        """Ping the worker."""
        return True

-     async def worker_stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]], **kwargs) -> Generator[str, None, None]:  # type: ignore
-         logger.info(f"Call PredictionWorker.worker_stream_generate_texts with kwargs: {kwargs}")
-         for s in self.generator.streamGenerate(prompt, **kwargs):
-             # logger.info(f"PredictionWorker.worker_stream_generate_texts -> yield ->{s}")
-             yield s
+     def can_stream(self) -> bool:
+         """Whether the worker can stream."""
+         return isinstance(self.generator, StreamingPipeline)

class GenericEngine(LLMEngine):
    base_worker_group = None
+     can_stream = None

    async def launch_engine(
        self,
@@ -338,11 +372,11 @@ async def launch_engine(
                    num_gpus_per_worker=scaling_config.num_gpus_per_worker
                )
                for worker, local_rank in zip(worker_group, local_ranks)
-                 # for worker in worker_group
            ]
        )

        self.base_worker_group = worker_group
+         self.can_stream = await asyncio.gather(*[worker_group[0].can_stream.remote()])
        return worker_group

    async def predict(
@@ -429,14 +463,45 @@ async def check_health(self):
                f"At least one prediction worker is dead. Dead workers: {dead_actors}. "
                "Reinitializing worker group."
            )
-
-     def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]:  # type: ignore
-         logger.info(f"GenericEngine.stream_generate_texts -> worker.length: {len(self.base_worker_group)}")
-         worker0 = self.base_worker_group[0]
-         for strHandle in worker0.worker_stream_generate_texts.remote(
-             prompt,
-             **self.args.model_config.generation.all_generate_kwargs if self.args.model_config.generation else {}
-         ):
-             val = ray.get(strHandle)
-             logger.info(f"GenericEngine.stream_generate_texts -> yield -> {val}")
-             yield val
+
+     async def stream(
+         self,
+         prompts: List[Prompt],
+         *,
+         timeout_s: float = 60,
+         start_timestamp: Optional[float] = None,
+         lock: asyncio.Lock,
+     ) -> Iterator[List[Response]]:
+         """Generate text for a list of prompts.
+
+         Args:
+             prompts (List[Prompt]): Batch of prompts to generate text from.
+             timeout_s (float, optional): Timeout for the generation. Defaults
+                 to 60. Ignored if start_timestamp is None.
+             start_timestamp (Optional[float], optional): Timestamp of when the
+                 batch was created. Defaults to None. If set, will early stop
+                 the generation.
+
+         Returns:
+             A list of generated texts.
+         """
+         if self.can_stream:
+             async with lock:
+                 tasks = [
+                     worker.stream.options(num_returns="streaming").remote(
+                         prompts,
+                         timeout_s=timeout_s,
+                         start_timestamp=start_timestamp,
+                         **self.args.model_config.generation.all_generate_kwargs,
+                     )
+                     for worker in self.base_worker_group
+                 ]
+                 async for result in tasks[0]:
+                     yield await result
+         else:
+             logger.warning(
+                 f"Pipeline {self.args.model_config.initialization.pipeline} does not support streaming. Ignoring queue."
+             )
+             yield await self.predict(
+                 prompts, timeout_s=timeout_s, start_timestamp=start_timestamp
+             )
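
For orientation, below is a minimal sketch of how the new streaming entry point might be consumed by a caller. It assumes an already-launched `GenericEngine` (here called `engine`) and a batch of `Prompt` objects; the names `engine`, `prompt_batch`, and `consume_stream`, and the way results are printed, are illustrative assumptions rather than code from this commit.

```python
import asyncio

async def consume_stream(engine, prompt_batch):
    """Drain the async generator returned by GenericEngine.stream (sketch).

    `engine` is assumed to be a GenericEngine whose launch_engine() has already
    completed, and `prompt_batch` a List[Prompt]; each item yielded by stream()
    is a List[Response] coming from the rank-0 worker's streaming pipeline.
    """
    lock = asyncio.Lock()  # stream() now requires a caller-supplied asyncio.Lock
    async for responses in engine.stream(
        prompt_batch,
        timeout_s=60,
        start_timestamp=None,
        lock=lock,
    ):
        for response in responses:
            print(response)  # Response is assumed to carry the generated text chunk

# Hypothetical usage: asyncio.run(consume_stream(engine, prompt_batch))
```

Note the design choice in `GenericEngine.stream`: the remote `stream` call is issued to every worker in `base_worker_group` with `num_returns="streaming"`, but only the first worker's object-ref generator (`tasks[0]`) is iterated, so the engine surfaces a single stream even when the model is sharded across several workers.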