@@ -278,9 +278,13 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
                           request_list: List[Dict[str, Any]],
                           *,
                           generation_config: VllmGenerationConfig,
+                          generation_info: Dict[str, Any],
                           lora_request: Optional['LoRARequest'] = None,
                           use_tqdm: bool = False,
                           **kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, int]]]:
+    for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
+        generation_info[key] = 0
+
     template.model = llm_engine
     tokenizer = template.tokenizer
     if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
@@ -327,22 +331,25 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
             # input_ids exceeds `max_length`. Please increase the value of `max_length`.
             resp_list[i] = {'response': '', 'history': history}
             continue
-
+        generation_info['num_prompt_tokens'] += len(inputs['input_ids'])
+        generation_info['num_samples'] += 1
         _add_vllm_request(
             llm_engine, inputs, request_id=str(i), generation_config=generation_config, **add_request_kwargs)
     return resp_list, agent_state
 
 
 @torch.inference_mode()
-def inference_stream_vllm(llm_engine: LLMEngine,
-                          template: Template,
-                          request_list: List[Dict[str, Any]],
-                          *,
-                          generation_config: Optional[VllmGenerationConfig] = None,
-                          generation_info: Optional[Dict[str, Any]] = None,
-                          lora_request: Optional['LoRARequest'] = None,
-                          use_tqdm: bool = False,
-                          **kwargs) -> Iterator[List[Dict[str, Any]]]:
+def inference_stream_vllm(
+        llm_engine: LLMEngine,
+        template: Template,
+        request_list: List[Dict[str, Any]],
+        *,
+        generation_config: Optional[VllmGenerationConfig] = None,
+        generation_info: Optional[Dict[str, Any]] = None,
+        lora_request: Optional['LoRARequest'] = None,
+        use_tqdm: bool = False,
+        flush_steps: Optional[int] = None,  # Ensuring efficiency
+        **kwargs) -> Iterator[List[Dict[str, Any]]]:
     """
     request_list: e.g. [{'query': 'hello!'}].
         The keys that can be included are: 'query', 'history', 'system'.
@@ -356,34 +363,43 @@ def inference_stream_vllm(llm_engine: LLMEngine,
     assert isinstance(generation_config, VllmGenerationConfig)
     request_list = deepcopy(request_list)
     generation_config = deepcopy(generation_config)
+    if generation_info is None:
+        generation_info = {}
+    else:
+        generation_info.clear()
+
     resp_list, agent_state = _prepare_vllm_request(
         llm_engine,
         template,
         request_list,
         generation_config=generation_config,
+        generation_info=generation_info,
         lora_request=lora_request,
         use_tqdm=use_tqdm,
         **kwargs)
 
-    if generation_info is None:
-        generation_info = {}
-    else:
-        generation_info.clear()
-
     if generation_config.use_beam_search:
         error_msg = 'Streaming generation does not support beam search.'
         raise ValueError(error_msg)
 
+    n_finished = 0
+    n_steps = 0
+    if flush_steps is None:
+        flush_steps = min(10, generation_info['num_samples'])
     print_idx_list = [[0] for _ in range(len(request_list))]
-    prog_bar = tqdm(total=len(request_list), dynamic_ncols=True, disable=not use_tqdm)
+    num_generated_tokens = [0] * len(request_list)
+    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
     while llm_engine.has_unfinished_requests():
-        for key in ['num_prompt_tokens', 'num_generated_tokens']:
-            generation_info[key] = 0
+        is_flush = False
+        n_steps += 1
         step_outputs = llm_engine.step()
         for output in step_outputs:
             i = int(output.request_id)
             request = request_list[i]
             generate_ids = output.outputs[0].token_ids
+            if not output.finished and n_steps % flush_steps != 0:
+                continue
+            is_flush = True
             safe_response = template.generate_ids_to_response(
                 generate_ids, output.finished, print_idx=print_idx_list[i])
             query = request['query']
@@ -394,14 +410,20 @@ def inference_stream_vllm(llm_engine: LLMEngine,
                 history[-1] = [query, safe_response]
             else:
                 history[-1][-1] = history[-1][-1][:agent_state[i][1]] + safe_response
-            generation_info['num_prompt_tokens'] += len(output.prompt_token_ids)
-            generation_info['num_generated_tokens'] += sum(len(_output.token_ids) for _output in output.outputs)
+
+            n_gen_tokens = sum(len(_output.token_ids) for _output in output.outputs)
+            generation_info['num_generated_tokens'] += n_gen_tokens - num_generated_tokens[i]
+            num_generated_tokens[i] = n_gen_tokens
+
             resp_list[i] = {'response': safe_response, 'history': history}
             if output.finished:
+                n_finished += 1
                 prog_bar.update()
+        if not is_flush:
+            continue
         runtime = time.perf_counter() - start_runtime
         generation_info['runtime'] = runtime
-        generation_info['samples/s'] = len(step_outputs) / runtime
+        generation_info['samples/s'] = n_finished / runtime
         generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
         yield resp_list
     prog_bar.close()
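For context, a minimal sketch of how the reworked streaming path could be driven from user code once `generation_info` and `flush_steps` are wired through. The `get_vllm_engine`/`get_default_template_type`/`get_template` setup, the `hf_tokenizer` attribute, and the `qwen-7b-chat` model type are assumptions borrowed from the project's usual examples, not something introduced by this commit:

```python
# Hypothetical usage sketch; the setup helpers and model type are assumed,
# only the generation_info / flush_steps behaviour comes from this diff.
from swift.llm import (get_default_template_type, get_template, get_vllm_engine,
                       inference_stream_vllm)

model_type = 'qwen-7b-chat'  # assumed model type
llm_engine = get_vllm_engine(model_type)
template = get_template(get_default_template_type(model_type), llm_engine.hf_tokenizer)

request_list = [{'query': 'hello!'}, {'query': 'Tell me a joke.'}]
generation_info = {}  # filled in-place; counters are reset in _prepare_vllm_request

# flush_steps throttles how often partial responses are decoded and yielded;
# per the diff it defaults to min(10, generation_info['num_samples']).
for resp_list in inference_stream_vllm(
        llm_engine, template, request_list,
        generation_info=generation_info, flush_steps=4):
    print(resp_list[0]['response'])

# After the generator is exhausted, generation_info holds num_prompt_tokens,
# num_generated_tokens, num_samples, runtime, samples/s and tokens/s.
print(generation_info)
```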
@@ -433,26 +455,25 @@ def inference_vllm(llm_engine: LLMEngine,
     assert isinstance(generation_config, VllmGenerationConfig)
     request_list = deepcopy(request_list)
     generation_config = deepcopy(generation_config)
+    if generation_info is None:
+        generation_info = {}
+    else:
+        generation_info.clear()
+
     resp_list, agent_state = _prepare_vllm_request(
         llm_engine,
         template,
         request_list,
         generation_config=generation_config,
+        generation_info=generation_info,
         lora_request=lora_request,
         use_tqdm=use_tqdm,
         **kwargs)
 
-    if generation_info is None:
-        generation_info = {}
-    else:
-        generation_info.clear()
-    for key in ['num_prompt_tokens', 'num_generated_tokens']:
-        generation_info[key] = 0
-
     tokenizer = template.tokenizer
     if use_tqdm:
         assert verbose is False
-    prog_bar = tqdm(total=len(request_list), dynamic_ncols=True, disable=not use_tqdm)
+    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
     outputs = []
     while llm_engine.has_unfinished_requests():
         step_outputs = llm_engine.step()
@@ -474,7 +495,6 @@ def inference_vllm(llm_engine: LLMEngine,
         else:
             history[-1][-1] = history[-1][-1] + response
 
-        generation_info['num_prompt_tokens'] += len(output.prompt_token_ids)
         generation_info['num_generated_tokens'] += sum(len(_output.token_ids) for _output in output.outputs)
         resp_list[i] = {'response': response, 'history': history}
        if verbose:
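The batched `inference_vllm` path now gets the same counters from `_prepare_vllm_request`, so a caller can inspect them after a plain call as well. Again an illustrative sketch, reusing the `llm_engine` and `template` from the streaming example above rather than anything added in this commit:

```python
from swift.llm import inference_vllm

# Reuses llm_engine and template from the streaming sketch above.
generation_info = {}
resp_list = inference_vllm(
    llm_engine, template,
    [{'query': 'Where is the capital of Zhejiang?'}],
    generation_info=generation_info)
print(resp_list[0]['response'])
# generation_info now holds at least num_prompt_tokens, num_generated_tokens and
# num_samples, initialised in _prepare_vllm_request and accumulated during generation.
print(generation_info)
```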