Commit e02ebfd

add "enable_prefix_caching" args for vllm engine. (#2939)
1 parent 5c73d6c commit e02ebfd

File tree

4 files changed (+51, -37 lines)


docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -277,6 +277,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, `modules_to_save`.
 - enforce_eager: Whether vllm uses pytorch eager mode or builds a cuda graph. Default is `False`. Setting it to True saves GPU memory but affects efficiency.
 - 🔥limit_mm_per_prompt: Controls multi-image usage in vllm, default is `None`. For example, pass `--limit_mm_per_prompt '{"image": 10, "video": 5}'`.
 - vllm_max_lora_rank: Default is `16`. The LoRA parameter supported by vllm.
+- enable_prefix_caching: Whether to enable vllm's Prefix Caching feature. Default is `False`. Setting it to True saves processing time on repeated request prefixes (e.g. the system prompt, long documents, or multi-turn dialogue).

 ### Merge Arguments

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@ Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai
 - enforce_eager: Whether vllm uses pytorch eager mode or establishes a cuda graph. Default is `False`. Setting to True can save memory but may affect efficiency.
 - 🔥limit_mm_per_prompt: Controls vllm using multiple images, default is `None`. For example, use `--limit_mm_per_prompt '{"image": 10, "video": 5}'`.
 - vllm_max_lora_rank: Default value is `16`. Parameters supported by vllm for LoRA.
+- enable_prefix_caching: Whether to enable the `Automatic Prefix Caching` feature for vllm. Default is `False`. Setting to True can save processing time on repeated request prefixes (such as a system prompt, long docs, or multi-turn dialog).

 ### Merge Arguments
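For readers new to the flag: vllm's automatic prefix caching reuses the KV cache computed for a shared prompt prefix across requests. A minimal sketch with the vllm Python API (the model id and prompts are placeholders, not part of this commit):

```python
from vllm import LLM, SamplingParams

# Enable automatic prefix caching in the engine (the same vllm option this commit exposes).
llm = LLM(model='Qwen/Qwen2.5-7B-Instruct', enable_prefix_caching=True)

# Both prompts share a long identical prefix (think: system prompt + reference document).
# With prefix caching on, the KV cache built for that prefix is reused by the second request.
shared_prefix = 'You are a helpful assistant.\n' + 'Reference document: ...\n'
prompts = [
    shared_prefix + 'Question: summarize the document.',
    shared_prefix + 'Question: list the key terms.',
]
outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)
```

The saving scales with the length of the shared prefix; for unrelated prompts the flag changes little.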

swift/llm/argument/infer_args.py

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,7 @@ class VllmArguments:
         enforce_eager (bool): Flag to enforce eager execution. Default is False.
         limit_mm_per_prompt (Optional[str]): Limit multimedia per prompt. Default is None.
         vllm_max_lora_rank (int): Maximum LoRA rank. Default is 16.
+        enable_prefix_caching (bool): Flag to enable automatic prefix caching. Default is False.
     """
     # vllm
     gpu_memory_utilization: float = 0.9
@@ -72,6 +73,7 @@ class VllmArguments:
     enforce_eager: bool = False
     limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 10, "video": 5}'
     vllm_max_lora_rank: int = 16
+    enable_prefix_caching: bool = False

     def __post_init__(self):
         self.limit_mm_per_prompt = ModelArguments.parse_to_dict(self.limit_mm_per_prompt)
@@ -92,6 +94,7 @@ def get_vllm_engine_kwargs(self):
             'max_lora_rank': self.vllm_max_lora_rank,
             'enable_lora': len(adapters) > 0,
             'max_loras': max(len(adapters), 1),
+            'enable_prefix_caching': self.enable_prefix_caching,
         }
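To make the data flow above concrete, here is a self-contained toy that mirrors only this hunk: how the new dataclass field ends up in the kwargs dict handed to the engine. The names mirror the real code, but this is a sketch, not the ms-swift implementation (the real method also pulls `adapters` and other settings from sibling argument mixins):

```python
from dataclasses import dataclass
from typing import Any, Dict, Sequence


@dataclass
class VllmArgsSketch:
    # Toy stand-in for the VllmArguments fields touched by this commit.
    vllm_max_lora_rank: int = 16
    enable_prefix_caching: bool = False

    def get_vllm_engine_kwargs(self, adapters: Sequence[str] = ()) -> Dict[str, Any]:
        # Same mapping as the hunk above: dataclass field -> engine kwarg.
        return {
            'max_lora_rank': self.vllm_max_lora_rank,
            'enable_lora': len(adapters) > 0,
            'max_loras': max(len(adapters), 1),
            'enable_prefix_caching': self.enable_prefix_caching,
        }


print(VllmArgsSketch(enable_prefix_caching=True).get_vllm_engine_kwargs())
# {'max_lora_rank': 16, 'enable_lora': False, 'max_loras': 1, 'enable_prefix_caching': True}
```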

swift/llm/infer/infer_engine/vllm_engine.py

Lines changed: 46 additions & 37 deletions
@@ -34,28 +34,30 @@
 class VllmEngine(InferEngine):

     def __init__(
-            self,
-            model_id_or_path: str,
-            torch_dtype: Optional[torch.dtype] = None,
-            *,
-            model_type: Optional[str] = None,
-            use_hf: Optional[bool] = None,
-            hub_token: Optional[str] = None,
-            revision: Optional[str] = None,
-            # engine_kwargs
-            gpu_memory_utilization: float = 0.9,
-            tensor_parallel_size: int = 1,
-            pipeline_parallel_size: int = 1,
-            max_model_len: Optional[int] = None,
-            max_num_seqs: int = 256,
-            disable_custom_all_reduce: bool = False,
-            enforce_eager: bool = False,
-            limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
-            # lora
-            enable_lora: bool = False,
-            max_loras: int = 1,
-            max_lora_rank: int = 16,
-            engine_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        self,
+        model_id_or_path: str,
+        torch_dtype: Optional[torch.dtype] = None,
+        *,
+        model_type: Optional[str] = None,
+        use_hf: Optional[bool] = None,
+        hub_token: Optional[str] = None,
+        revision: Optional[str] = None,
+        # engine_kwargs
+        gpu_memory_utilization: float = 0.9,
+        tensor_parallel_size: int = 1,
+        pipeline_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
+        max_num_seqs: int = 256,
+        disable_custom_all_reduce: bool = False,
+        enforce_eager: bool = False,
+        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
+        # lora
+        enable_lora: bool = False,
+        max_loras: int = 1,
+        max_lora_rank: int = 16,
+        enable_prefix_caching: bool = False,
+        engine_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         self.processor = get_model_tokenizer(
             model_id_or_path,
             torch_dtype,
@@ -79,7 +81,9 @@ def __init__(
             enable_lora=enable_lora,
             max_loras=max_loras,
             max_lora_rank=max_lora_rank,
-            engine_kwargs=engine_kwargs)
+            enable_prefix_caching=enable_prefix_caching,
+            engine_kwargs=engine_kwargs,
+        )

         self._prepare_engine()
         self._load_generation_config()
@@ -91,19 +95,22 @@ def _prepare_engine(self) -> None:
         engine = AsyncLLMEngine.from_engine_args(self.engine_args)
         self.engine = engine

-    def _prepare_engine_kwargs(self,
-                               gpu_memory_utilization: float = 0.9,
-                               tensor_parallel_size: int = 1,
-                               pipeline_parallel_size: int = 1,
-                               max_model_len: Optional[int] = None,
-                               max_num_seqs: int = 256,
-                               disable_custom_all_reduce: bool = False,
-                               enforce_eager: bool = False,
-                               limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
-                               enable_lora: bool = False,
-                               max_loras: int = 1,
-                               max_lora_rank: int = 16,
-                               engine_kwargs: Optional[Dict[str, Any]] = None) -> None:
+    def _prepare_engine_kwargs(
+        self,
+        gpu_memory_utilization: float = 0.9,
+        tensor_parallel_size: int = 1,
+        pipeline_parallel_size: int = 1,
+        max_model_len: Optional[int] = None,
+        max_num_seqs: int = 256,
+        disable_custom_all_reduce: bool = False,
+        enforce_eager: bool = False,
+        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
+        enable_lora: bool = False,
+        max_loras: int = 1,
+        max_lora_rank: int = 16,
+        enable_prefix_caching: bool = False,
+        engine_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         if engine_kwargs is None:
             engine_kwargs = {}
         disable_log_stats = engine_kwargs.pop('disable_log_stats', True)
@@ -136,7 +143,9 @@ def _prepare_engine_kwargs(self,
             disable_custom_all_reduce=disable_custom_all_reduce,
             enforce_eager=enforce_eager,
             trust_remote_code=True,
-            **engine_kwargs)
+            enable_prefix_caching=enable_prefix_caching,
+            **engine_kwargs,
+        )
         self.engine_args = engine_args
         self.enable_lora = enable_lora
         if max_model_len is not None:
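Putting it together, the new keyword can now be passed straight to `VllmEngine`, which forwards it to vllm's `AsyncEngineArgs` as the last hunk shows. A hedged usage sketch, assuming ms-swift 3.x import paths and its `infer` API (the model id is a placeholder):

```python
from swift.llm import InferRequest, RequestConfig, VllmEngine  # assumed import path

# enable_prefix_caching is forwarded to AsyncEngineArgs(..., enable_prefix_caching=True).
engine = VllmEngine('Qwen/Qwen2.5-7B-Instruct', enable_prefix_caching=True)

# The identical system prompt is the shared prefix whose KV cache gets reused.
system = 'You are a meticulous reviewer. Always answer in one short paragraph.'
requests = [
    InferRequest(messages=[{'role': 'system', 'content': system},
                           {'role': 'user', 'content': 'What does prefix caching cache?'}]),
    InferRequest(messages=[{'role': 'system', 'content': system},
                           {'role': 'user', 'content': 'When does it help the most?'}]),
]
responses = engine.infer(requests, RequestConfig(max_tokens=64, temperature=0))
for resp in responses:
    print(resp.choices[0].message.content)
```

Since the docs hunks above add the option to the command-line parameter lists, it should also be reachable from the CLI as `--enable_prefix_caching true`.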
