|
18 | 18 | create_sort_beams_key_function, |
19 | 19 | ) |
20 | 20 | from vllm.config import ( |
| 21 | + AttentionConfig, |
21 | 22 | CompilationConfig, |
22 | 23 | PoolerConfig, |
23 | 24 | ProfilerConfig, |
@@ -175,6 +176,10 @@ class LLM: |
175 | 176 | compilation_config: Either an integer or a dictionary. If it is an |
176 | 177 | integer, it is used as the mode of compilation optimization. If it |
177 | 178 | is a dictionary, it can specify the full compilation configuration. |
| 179 | + attention_config: Configuration for attention mechanisms. Can be a |
| 180 | + dictionary or an AttentionConfig instance. If a dictionary, it will |
| 181 | + be converted to an AttentionConfig. Allows specifying the attention |
| 182 | + backend and other attention-related settings. |
178 | 183 | **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. |
179 | 184 |
|
180 | 185 | Note: |
@@ -213,6 +218,7 @@ def __init__( |
213 | 218 | | StructuredOutputsConfig |
214 | 219 | | None = None, |
215 | 220 | profiler_config: dict[str, Any] | ProfilerConfig | None = None, |
| 221 | + attention_config: dict[str, Any] | AttentionConfig | None = None, |
216 | 222 | kv_cache_memory_bytes: int | None = None, |
217 | 223 | compilation_config: int | dict[str, Any] | CompilationConfig | None = None, |
218 | 224 | logits_processors: list[str | type[LogitsProcessor]] | None = None, |
@@ -252,51 +258,28 @@ def __init__( |
252 | 258 | if hf_overrides is None: |
253 | 259 | hf_overrides = {} |
254 | 260 |
|
255 | | - if compilation_config is not None: |
256 | | - if isinstance(compilation_config, int): |
257 | | - compilation_config_instance = CompilationConfig( |
258 | | - mode=CompilationMode(compilation_config) |
259 | | - ) |
260 | | - elif isinstance(compilation_config, dict): |
261 | | - compilation_config_instance = CompilationConfig( |
262 | | - **{ |
263 | | - k: v |
264 | | - for k, v in compilation_config.items() |
265 | | - if is_init_field(CompilationConfig, k) |
266 | | - } |
267 | | - ) |
268 | | - else: |
269 | | - compilation_config_instance = compilation_config |
270 | | - else: |
271 | | - compilation_config_instance = CompilationConfig() |
272 | | - |
273 | | - if structured_outputs_config is not None: |
274 | | - if isinstance(structured_outputs_config, dict): |
275 | | - structured_outputs_instance = StructuredOutputsConfig( |
276 | | - **{ |
277 | | - k: v |
278 | | - for k, v in structured_outputs_config.items() |
279 | | - if is_init_field(StructuredOutputsConfig, k) |
280 | | - } |
281 | | - ) |
282 | | - else: |
283 | | - structured_outputs_instance = structured_outputs_config |
284 | | - else: |
285 | | - structured_outputs_instance = StructuredOutputsConfig() |
286 | | - |
287 | | - if profiler_config is not None: |
288 | | - if isinstance(profiler_config, dict): |
289 | | - profiler_config_instance = ProfilerConfig( |
290 | | - **{ |
291 | | - k: v |
292 | | - for k, v in profiler_config.items() |
293 | | - if is_init_field(ProfilerConfig, k) |
294 | | - } |
295 | | - ) |
296 | | - else: |
297 | | - profiler_config_instance = profiler_config |
| 261 | + def _make_config(value: Any, cls: type[_R]) -> _R: |
| 262 | + """Convert dict/None/instance to a config instance.""" |
| 263 | + if value is None: |
| 264 | + return cls() |
| 265 | + if isinstance(value, dict): |
| 266 | + return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type] |
| 267 | + return value |
| 268 | + |
| 269 | + if isinstance(compilation_config, int): |
| 270 | + compilation_config_instance = CompilationConfig( |
| 271 | + mode=CompilationMode(compilation_config) |
| 272 | + ) |
298 | 273 | else: |
299 | | - profiler_config_instance = ProfilerConfig() |
| 274 | + compilation_config_instance = _make_config( |
| 275 | + compilation_config, CompilationConfig |
| 276 | + ) |
| 277 | + |
| 278 | + structured_outputs_instance = _make_config( |
| 279 | + structured_outputs_config, StructuredOutputsConfig |
| 280 | + ) |
| 281 | + profiler_config_instance = _make_config(profiler_config, ProfilerConfig) |
| 282 | + attention_config_instance = _make_config(attention_config, AttentionConfig) |
300 | 283 |
|
301 | 284 | # warn about single-process data parallel usage. |
302 | 285 | _dp_size = int(kwargs.get("data_parallel_size", 1)) |
@@ -341,6 +324,7 @@ def __init__( |
341 | 324 | pooler_config=pooler_config, |
342 | 325 | structured_outputs_config=structured_outputs_instance, |
343 | 326 | profiler_config=profiler_config_instance, |
| 327 | + attention_config=attention_config_instance, |
344 | 328 | compilation_config=compilation_config_instance, |
345 | 329 | logits_processors=logits_processors, |
346 | 330 | **kwargs, |
|
0 commit comments