@@ -280,40 +280,27 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]
 
+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
         if not self._has_remaining_steps(seq_group_metadata_list):
 
-            # Clear outputs on scheduler iteration start
-            ctx.request_outputs.clear()
-
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
 
-            # Detect async + multi-step
-            use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                        and allow_async_output_proc)
+            ctx.seq_group_metadata_list = seq_group_metadata_list
+            ctx.scheduler_outputs = scheduler_outputs
 
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
-
-            # For async + multi-step, init the queue
-            if use_async_and_multi_step:
-                assert len(ctx.output_queue) == 0
-                assert seq_group_metadata_list is not None
-                ctx.output_queue.append(
-                    (None, seq_group_metadata_list, scheduler_outputs))
+                self._process_model_outputs(ctx=ctx)
 
             if (self.scheduler_config.is_multi_step
                     and scheduler_outputs.num_lookahead_slots > 0):
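
With use_async_and_multi_step gone, everything the async path needs now lives on the per-virtual-engine SchedulerContext: request outputs are cleared once per scheduler iteration, and the freshly scheduled metadata is cached on ctx for the async callback to pick up. Below is a minimal sketch of such a context object; the field names follow the diff, while the deque-backed queue and the loose typing are assumptions for illustration only.

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class SchedulerContextSketch:
    # Finished RequestOutputs for the current iteration; cleared at the top
    # of every scheduler iteration, as in the hunk above.
    request_outputs: List[Any] = field(default_factory=list)
    # Pending (output, metadata, scheduler_outputs, is_async, is_last_step)
    # tuples awaiting post-processing.
    output_queue: Deque[Tuple] = field(default_factory=deque)
    # Schedule results cached so the async callback can reach them later.
    seq_group_metadata_list: Optional[List[Any]] = None
    scheduler_outputs: Optional[Any] = None
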
@@ -351,26 +338,20 @@ async def step_async(
                 last_sampled_token_ids=last_sampled_token_ids)
 
             if allow_async_output_proc:
-                async_callback = self.async_callback_multi_step[
-                    virtual_engine] if use_async_and_multi_step \
-                    else self.async_callback[virtual_engine]
-
-                execute_model_req.async_callback = async_callback
-                execute_model_req.use_async_and_multi_step = \
-                    use_async_and_multi_step
+                execute_model_req.async_callback = self.async_callbacks[
+                    virtual_engine]
 
             # Execute the model.
             output = await self.model_executor.execute_model_async(
                 execute_model_req)
+
             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
                 self._update_cached_scheduler_output(virtual_engine, output)
         else:
-            if not use_async_and_multi_step and len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
             output = []
 
         # Finish the current step for all the sequence groups.
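
The two mode-specific callback tables (async_callback and async_callback_multi_step) collapse into a single self.async_callbacks list indexed by virtual engine, and use_async_and_multi_step no longer travels on the request. One plausible construction is sketched below with functools.partial, so each callback arrives pre-bound to its own context; the class bodies are placeholders, not vLLM's actual code.

import functools


class _Ctx:
    """Placeholder for vLLM's per-virtual-engine SchedulerContext."""

    def __init__(self) -> None:
        self.output_queue = []


class EngineSketch:

    def __init__(self, pipeline_parallel_size: int) -> None:
        self.scheduler_contexts = [
            _Ctx() for _ in range(pipeline_parallel_size)
        ]
        # One callback per virtual engine, bound to its context up front, so
        # the executor only has to index by virtual_engine.
        self.async_callbacks = [
            functools.partial(self._process_model_outputs, ctx=ctx)
            for ctx in self.scheduler_contexts
        ]

    def _process_model_outputs(self, ctx: _Ctx) -> None:
        # Drain queued step outputs for this context (real work elided).
        while ctx.output_queue:
            ctx.output_queue.pop(0)
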
@@ -384,24 +365,22 @@ async def step_async(
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
 
-            if use_async_and_multi_step:
-                # For async + multi-step, clear the queue
-                ctx.output_queue.clear()
-            else:
-                ctx.output_queue.append(
-                    (output, seq_group_metadata_list, scheduler_outputs))
+            is_async = allow_async_output_proc
+            is_last_step = True
+            ctx.output_queue.append(
+                (output, seq_group_metadata_list, scheduler_outputs, is_async,
+                 is_last_step))
 
-                if output and allow_async_output_proc:
-                    assert len(
-                        output
-                    ) == 1, "Multi step decoding does not work with async output processing."  # noqa: E501
-                    self._advance_to_next_step(
-                        output[0], seq_group_metadata_list,
-                        scheduler_outputs.scheduled_seq_groups)
+            if output and allow_async_output_proc:
+                assert len(
+                    output
+                ) == 1, "Async postprocessor expects only a single output set"
+                self._advance_to_next_step(
+                    output[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=False)
+                self._process_model_outputs(ctx=ctx)
 
             # Log stats.
             self.do_log_stats(scheduler_outputs, output)
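
Queue entries grow from a 3-tuple to a 5-tuple: is_async records which path produced the entry, and is_last_step lets multi-step scheduling defer emission of request outputs to the batch's final step. The consumer sketch below mirrors that layout; the handling logic is illustrative, not vLLM's _process_model_outputs.

from collections import deque

output_queue = deque()

# Producer side, mirroring the hunk: one 5-tuple per engine step.
output_queue.append((["sampler_output"], ["seq_group_metadata"],
                     "scheduler_outputs", True, True))

# Consumer side: only the last step of a multi-step batch should surface
# finished outputs; is_async flags entries queued via the async callback.
while output_queue:
    (output, seq_group_metadata_list, scheduler_outputs, is_async,
     is_last_step) = output_queue.popleft()
    if is_last_step:
        print(f"emit request outputs (is_async={is_async}): {output}")
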
@@ -411,17 +390,12 @@ async def step_async(
 
         else:
             # Multi-step case
-            if use_async_and_multi_step:
-                return []
-            else:
-                ctx.request_outputs = []
+            return ctx.request_outputs
 
         if not self.has_unfinished_requests():
             # Drain async postprocessor (if exists)
             if len(ctx.output_queue) > 0:
-                assert not self.scheduler_config.is_multi_step
-                self._process_model_outputs(virtual_engine=virtual_engine,
-                                            is_async=True)
+                self._process_model_outputs(ctx=ctx)
             assert len(ctx.output_queue) == 0
 
         return ctx.request_outputs
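
The tail of step_async also simplifies: with the multi-step assert gone, any entries still queued once no unfinished requests remain are flushed through the same _process_model_outputs(ctx=ctx) call, leaving the queue provably empty. A toy check of that invariant, with a hypothetical drain_if_idle standing in for the real call:

from collections import deque


def drain_if_idle(output_queue: deque, has_unfinished: bool) -> None:
    if not has_unfinished:
        while output_queue:  # stand-in for _process_model_outputs(ctx=ctx)
            output_queue.popleft()
        assert len(output_queue) == 0


queue = deque([("output", None, None, False, True)])
drain_if_idle(queue, has_unfinished=False)
print(len(queue))  # 0
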
@@ -640,6 +614,17 @@ def __init__(self,
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)
 
+        # This ensures quick processing of request outputs
+        # so the append to asyncio queues is not delayed,
+        # especially for multi-step.
+        #
+        # TODO: Currently, disabled for engine_use_ray, ask
+        # Cody/Will/Woosuk about this case.
+        self.use_process_request_outputs_callback = not self.engine_use_ray
+        if self.use_process_request_outputs_callback:
+            self.engine.process_request_outputs_callback = \
+                self.process_request_outputs
+
         if self.engine_use_ray:
             print_warning_once(
                 "DEPRECATED. `--engine-use-ray` is deprecated and will "
@@ -883,13 +868,27 @@ async def engine_step(self, virtual_engine: int) -> bool:
         request_outputs = await self.engine.step_async(virtual_engine)
 
         # Put the outputs into the corresponding streams.
-        finished = True
+        # If used as a callback, then already invoked inside
+        # LLMEngine's _process_model_outputs
+        if not self.use_process_request_outputs_callback:
+            all_finished = self.process_request_outputs(request_outputs)
+        else:
+            # For callback case, we only need to detect when all
+            # requests are finished
+            all_finished = all(request_output.finished
+                               for request_output in request_outputs)
+
+        return not all_finished
+
+    def process_request_outputs(self, request_outputs) -> bool:
+        # Put the outputs into the corresponding streams.
+        all_finished = True
         for request_output in request_outputs:
             self._request_tracker.process_request_output(
                 request_output, verbose=self.log_requests)
-            finished = finished and request_output.finished
+            all_finished = all_finished and request_output.finished
 
-        return not finished
+        return all_finished
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
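
After the rename, the helper answers "did every request finish?" while engine_step keeps answering "is there more work to do?", so the negation happens exactly once. A tiny check of that contract, with _Out standing in for RequestOutput:

class _Out:

    def __init__(self, finished: bool) -> None:
        self.finished = finished


def process_request_outputs(request_outputs) -> bool:
    all_finished = True
    for request_output in request_outputs:
        all_finished = all_finished and request_output.finished
    return all_finished


outputs = [_Out(True), _Out(False)]
print(not process_request_outputs(outputs))  # True: engine has more work
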