Commit 94a0658

fix generation_info efficiency (#1359)
1 parent 112b194 commit 94a0658

5 files changed: 52 additions and 34 deletions

docs/source/LLM/VLLM推理加速与部署.md

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ history1 = resp_list[1]['history']
 request_list = [{'query': '这有什么好吃的', 'history': history1}]
 gen = inference_stream_vllm(llm_engine, template, request_list, generation_info=generation_info)
 query = request_list[0]['query']
-history1 = resp_list[1]['history']
 print_idx = 0
 print(f'query: {query}\nresponse: ', end='')
 for resp_list in gen:

docs/source/LLM/自定义与拓展.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 # Uses the dataset_id from modelscope by default; the dataset_id from huggingface is also supported
 --dataset {dataset_id} {dataset_path} HF::{dataset_id}
 
-# Dataset mixing: the following takes the subset1 and subset2 subsets from dataset_id and samples 20000 records. If `#{dataset_sample}` is not used, all samples in the dataset are used
+# Dataset mixing: the following takes the subset1 and subset2 subsets from dataset_id and randomly samples 20000 records. If `#{dataset_sample}` is not used, all samples in the dataset are used
 --dataset {dataset_name}#20000 {dataset_id}:{subset1}/{subset2}#20000 {dataset_path}#10000
 ```
 

docs/source_en/LLM/Customization.md

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ The specified format for each dataset is as follows: `[HF or MS::]{dataset_name}
 # Defaulting to using the dataset_id from modelscope, while also supporting the dataset_id from huggingface.
 --dataset {dataset_id} {dataset_path} HF::{dataset_id}
 
-# Dataset Mixing: the following command takes subset1 and subset2 from dataset_id and samples 20,000 records. If `#{dataset_sample}` is not used, all samples from the dataset will be used.
+# Dataset Mixing: the following command takes subset1 and subset2 from dataset_id and randomly samples 20,000 records. If `#{dataset_sample}` is not used, all samples from the dataset will be used.
 --dataset {dataset_name}#20000 {dataset_id}:{subset1}/{subset2}#20000 {dataset_path}#10000
 ```
 

docs/source_en/LLM/VLLM-inference-acceleration-and-deployment.md

Lines changed: 0 additions & 1 deletion
@@ -56,7 +56,6 @@ history1 = resp_list[1]['history']
 request_list = [{'query': 'Is there anything tasty here?', 'history': history1}]
 gen = inference_stream_vllm(llm_engine, template, request_list, generation_info=generation_info)
 query = request_list[0]['query']
-history1 = resp_list[1]['history']
 print_idx = 0
 print(f'query: {query}\nresponse: ', end='')
 for resp_list in gen:

swift/llm/utils/vllm_utils.py

Lines changed: 50 additions & 30 deletions
@@ -278,9 +278,13 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
                           request_list: List[Dict[str, Any]],
                           *,
                           generation_config: VllmGenerationConfig,
+                          generation_info: Dict[str, Any],
                           lora_request: Optional['LoRARequest'] = None,
                           use_tqdm: bool = False,
                           **kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, int]]]:
+    for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
+        generation_info[key] = 0
+
     template.model = llm_engine
     tokenizer = template.tokenizer
     if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
@@ -327,22 +331,25 @@ def _prepare_vllm_request(llm_engine: LLMEngine,
             # input_ids exceeds `max_length`. Please increase the value of `max_length`.
             resp_list[i] = {'response': '', 'history': history}
             continue
-
+        generation_info['num_prompt_tokens'] += len(inputs['input_ids'])
+        generation_info['num_samples'] += 1
         _add_vllm_request(
             llm_engine, inputs, request_id=str(i), generation_config=generation_config, **add_request_kwargs)
     return resp_list, agent_state


 @torch.inference_mode()
-def inference_stream_vllm(llm_engine: LLMEngine,
-                          template: Template,
-                          request_list: List[Dict[str, Any]],
-                          *,
-                          generation_config: Optional[VllmGenerationConfig] = None,
-                          generation_info: Optional[Dict[str, Any]] = None,
-                          lora_request: Optional['LoRARequest'] = None,
-                          use_tqdm: bool = False,
-                          **kwargs) -> Iterator[List[Dict[str, Any]]]:
+def inference_stream_vllm(
+        llm_engine: LLMEngine,
+        template: Template,
+        request_list: List[Dict[str, Any]],
+        *,
+        generation_config: Optional[VllmGenerationConfig] = None,
+        generation_info: Optional[Dict[str, Any]] = None,
+        lora_request: Optional['LoRARequest'] = None,
+        use_tqdm: bool = False,
+        flush_steps: Optional[int] = None,  # Ensuring efficiency
+        **kwargs) -> Iterator[List[Dict[str, Any]]]:
     """
     request_list: e.g. [{'query': 'hello!'}].
         The keys that can be included are: 'query', 'history', 'system'.
@@ -356,34 +363,43 @@ def inference_stream_vllm(llm_engine: LLMEngine,
     assert isinstance(generation_config, VllmGenerationConfig)
     request_list = deepcopy(request_list)
     generation_config = deepcopy(generation_config)
+    if generation_info is None:
+        generation_info = {}
+    else:
+        generation_info.clear()
+
     resp_list, agent_state = _prepare_vllm_request(
         llm_engine,
         template,
         request_list,
         generation_config=generation_config,
+        generation_info=generation_info,
         lora_request=lora_request,
         use_tqdm=use_tqdm,
         **kwargs)

-    if generation_info is None:
-        generation_info = {}
-    else:
-        generation_info.clear()
-
     if generation_config.use_beam_search:
         error_msg = 'Streaming generation does not support beam search.'
         raise ValueError(error_msg)

+    n_finished = 0
+    n_steps = 0
+    if flush_steps is None:
+        flush_steps = min(10, generation_info['num_samples'])
     print_idx_list = [[0] for _ in range(len(request_list))]
-    prog_bar = tqdm(total=len(request_list), dynamic_ncols=True, disable=not use_tqdm)
+    num_generated_tokens = [0] * len(request_list)
+    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
     while llm_engine.has_unfinished_requests():
-        for key in ['num_prompt_tokens', 'num_generated_tokens']:
-            generation_info[key] = 0
+        is_flush = False
+        n_steps += 1
         step_outputs = llm_engine.step()
         for output in step_outputs:
             i = int(output.request_id)
             request = request_list[i]
             generate_ids = output.outputs[0].token_ids
+            if not output.finished and n_steps % flush_steps != 0:
+                continue
+            is_flush = True
             safe_response = template.generate_ids_to_response(
                 generate_ids, output.finished, print_idx=print_idx_list[i])
             query = request['query']
@@ -394,14 +410,20 @@ def inference_stream_vllm(llm_engine: LLMEngine,
                 history[-1] = [query, safe_response]
             else:
                 history[-1][-1] = history[-1][-1][:agent_state[i][1]] + safe_response
-            generation_info['num_prompt_tokens'] += len(output.prompt_token_ids)
-            generation_info['num_generated_tokens'] += sum(len(_output.token_ids) for _output in output.outputs)
+
+            n_gen_tokens = sum(len(_output.token_ids) for _output in output.outputs)
+            generation_info['num_generated_tokens'] += n_gen_tokens - num_generated_tokens[i]
+            num_generated_tokens[i] = n_gen_tokens
+
             resp_list[i] = {'response': safe_response, 'history': history}
             if output.finished:
+                n_finished += 1
                 prog_bar.update()
+        if not is_flush:
+            continue
         runtime = time.perf_counter() - start_runtime
         generation_info['runtime'] = runtime
-        generation_info['samples/s'] = len(step_outputs) / runtime
+        generation_info['samples/s'] = n_finished / runtime
         generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
         yield resp_list
     prog_bar.close()
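
The hunks above make generation_info cumulative over the whole streaming call (each request's token count is tracked in num_generated_tokens[i] and only the delta is added) and introduce flush_steps, which skips re-decoding unfinished requests on most engine steps. The sketch below is not part of the commit; it only illustrates how a caller might read these statistics. The engine and template setup (get_vllm_engine, get_default_template_type, get_template) and the qwen_7b_chat model id are borrowed from the swift VLLM docs and should be treated as assumptions here.

```python
# Hedged usage sketch, not part of this commit. Setup helpers and model id are
# assumptions taken from the swift VLLM docs touched elsewhere in this commit.
from swift.llm import (ModelType, get_default_template_type, get_template,
                       get_vllm_engine, inference_stream_vllm)

model_type = ModelType.qwen_7b_chat  # hypothetical model choice
llm_engine = get_vllm_engine(model_type)
template = get_template(get_default_template_type(model_type), llm_engine.hf_tokenizer)

request_list = [{'query': 'hello!'}, {'query': 'who are you?'}]
generation_info = {}  # filled in place by _prepare_vllm_request

# flush_steps (new in this commit) controls how often unfinished requests are
# decoded and yielded; when None it defaults to min(10, generation_info['num_samples']).
gen = inference_stream_vllm(llm_engine, template, request_list,
                            generation_info=generation_info, flush_steps=4)
for resp_list in gen:
    # the counters are now cumulative for the whole call, not reset per engine step
    print(resp_list[0]['response'], generation_info['num_generated_tokens'])
print(generation_info)  # num_prompt_tokens, num_generated_tokens, num_samples,
                        # runtime, samples/s, tokens/s
```

The remaining hunks apply the same generation_info initialization to the non-streaming inference_vllm path.
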
@@ -433,26 +455,25 @@ def inference_vllm(llm_engine: LLMEngine,
     assert isinstance(generation_config, VllmGenerationConfig)
     request_list = deepcopy(request_list)
     generation_config = deepcopy(generation_config)
+    if generation_info is None:
+        generation_info = {}
+    else:
+        generation_info.clear()
+
     resp_list, agent_state = _prepare_vllm_request(
         llm_engine,
         template,
         request_list,
         generation_config=generation_config,
+        generation_info=generation_info,
         lora_request=lora_request,
         use_tqdm=use_tqdm,
         **kwargs)

-    if generation_info is None:
-        generation_info = {}
-    else:
-        generation_info.clear()
-    for key in ['num_prompt_tokens', 'num_generated_tokens']:
-        generation_info[key] = 0
-
     tokenizer = template.tokenizer
     if use_tqdm:
         assert verbose is False
-    prog_bar = tqdm(total=len(request_list), dynamic_ncols=True, disable=not use_tqdm)
+    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
     outputs = []
     while llm_engine.has_unfinished_requests():
         step_outputs = llm_engine.step()
@@ -474,7 +495,6 @@ def inference_vllm(llm_engine: LLMEngine,
         else:
             history[-1][-1] = history[-1][-1] + response

-        generation_info['num_prompt_tokens'] += len(output.prompt_token_ids)
         generation_info['num_generated_tokens'] += sum(len(_output.token_ids) for _output in output.outputs)
         resp_list[i] = {'response': response, 'history': history}
         if verbose:
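
For the non-streaming path, generation_info is now cleared and handed to _prepare_vllm_request before decoding starts, so num_prompt_tokens and num_samples are known up front and the progress-bar total uses num_samples; the redundant per-output num_prompt_tokens accumulation is dropped. Below is a hedged sketch of the batch call, under the same assumed setup helpers and hypothetical model id as the streaming sketch above.

```python
# Hedged usage sketch, not part of this commit; setup helpers and model id are assumptions.
from swift.llm import (ModelType, get_default_template_type, get_template,
                       get_vllm_engine, inference_vllm)

model_type = ModelType.qwen_7b_chat  # hypothetical model choice
llm_engine = get_vllm_engine(model_type)
template = get_template(get_default_template_type(model_type), llm_engine.hf_tokenizer)

request_list = [{'query': 'hello!'}, {'query': 'who are you?'}]
generation_info = {}
resp_list = inference_vllm(llm_engine, template, request_list,
                           generation_info=generation_info,
                           use_tqdm=True, verbose=False)  # use_tqdm requires verbose=False
for request, resp in zip(request_list, resp_list):
    print(f"{request['query']} -> {resp['response']}")

# num_prompt_tokens and num_samples are set by _prepare_vllm_request before decoding;
# num_generated_tokens is accumulated over the finished outputs.
print(generation_info['num_samples'],
      generation_info['num_prompt_tokens'],
      generation_info['num_generated_tokens'])
```
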