Skip to content

Commit a182be4

Browse files
[UX][Attention] Add attention_config argument to LLM() (vllm-project#30710)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
1 parent c01d589 commit a182be4

1 file changed

Lines changed: 28 additions & 44 deletions

File tree

vllm/entrypoints/llm.py

Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
create_sort_beams_key_function,
1919
)
2020
from vllm.config import (
21+
AttentionConfig,
2122
CompilationConfig,
2223
PoolerConfig,
2324
ProfilerConfig,
@@ -175,6 +176,10 @@ class LLM:
175176
compilation_config: Either an integer or a dictionary. If it is an
176177
integer, it is used as the mode of compilation optimization. If it
177178
is a dictionary, it can specify the full compilation configuration.
179+
attention_config: Configuration for attention mechanisms. Can be a
180+
dictionary or an AttentionConfig instance. If a dictionary, it will
181+
be converted to an AttentionConfig. Allows specifying the attention
182+
backend and other attention-related settings.
178183
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
179184
180185
Note:
@@ -213,6 +218,7 @@ def __init__(
213218
| StructuredOutputsConfig
214219
| None = None,
215220
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
221+
attention_config: dict[str, Any] | AttentionConfig | None = None,
216222
kv_cache_memory_bytes: int | None = None,
217223
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
218224
logits_processors: list[str | type[LogitsProcessor]] | None = None,
@@ -252,51 +258,28 @@ def __init__(
252258
if hf_overrides is None:
253259
hf_overrides = {}
254260

255-
if compilation_config is not None:
256-
if isinstance(compilation_config, int):
257-
compilation_config_instance = CompilationConfig(
258-
mode=CompilationMode(compilation_config)
259-
)
260-
elif isinstance(compilation_config, dict):
261-
compilation_config_instance = CompilationConfig(
262-
**{
263-
k: v
264-
for k, v in compilation_config.items()
265-
if is_init_field(CompilationConfig, k)
266-
}
267-
)
268-
else:
269-
compilation_config_instance = compilation_config
270-
else:
271-
compilation_config_instance = CompilationConfig()
272-
273-
if structured_outputs_config is not None:
274-
if isinstance(structured_outputs_config, dict):
275-
structured_outputs_instance = StructuredOutputsConfig(
276-
**{
277-
k: v
278-
for k, v in structured_outputs_config.items()
279-
if is_init_field(StructuredOutputsConfig, k)
280-
}
281-
)
282-
else:
283-
structured_outputs_instance = structured_outputs_config
284-
else:
285-
structured_outputs_instance = StructuredOutputsConfig()
286-
287-
if profiler_config is not None:
288-
if isinstance(profiler_config, dict):
289-
profiler_config_instance = ProfilerConfig(
290-
**{
291-
k: v
292-
for k, v in profiler_config.items()
293-
if is_init_field(ProfilerConfig, k)
294-
}
295-
)
296-
else:
297-
profiler_config_instance = profiler_config
261+
def _make_config(value: Any, cls: type[_R]) -> _R:
    """Normalize a user-supplied config argument into a config object.

    Args:
        value: ``None``, a ``dict`` of constructor arguments, or any other
            object (expected to already be a ``cls`` instance).
        cls: The config class to build when ``value`` is ``None`` or a dict.

    Returns:
        ``cls()`` when ``value`` is ``None``; ``cls(**kwargs)`` when ``value``
        is a dict, where ``kwargs`` keeps only the entries whose key passes
        ``is_init_field(cls, key)``; otherwise ``value`` itself, unchanged.
    """
    if isinstance(value, dict):
        # Keys that `is_init_field` rejects are dropped rather than forwarded.
        init_kwargs = {}
        for key, val in value.items():
            if is_init_field(cls, key):
                init_kwargs[key] = val
        return cls(**init_kwargs)  # type: ignore[arg-type]
    # Not a dict: build a default instance for None, otherwise pass through.
    return cls() if value is None else value
268+
269+
if isinstance(compilation_config, int):
270+
compilation_config_instance = CompilationConfig(
271+
mode=CompilationMode(compilation_config)
272+
)
298273
else:
299-
profiler_config_instance = ProfilerConfig()
274+
compilation_config_instance = _make_config(
275+
compilation_config, CompilationConfig
276+
)
277+
278+
structured_outputs_instance = _make_config(
279+
structured_outputs_config, StructuredOutputsConfig
280+
)
281+
profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
282+
attention_config_instance = _make_config(attention_config, AttentionConfig)
300283

301284
# warn about single-process data parallel usage.
302285
_dp_size = int(kwargs.get("data_parallel_size", 1))
@@ -341,6 +324,7 @@ def __init__(
341324
pooler_config=pooler_config,
342325
structured_outputs_config=structured_outputs_instance,
343326
profiler_config=profiler_config_instance,
327+
attention_config=attention_config_instance,
344328
compilation_config=compilation_config_instance,
345329
logits_processors=logits_processors,
346330
**kwargs,

0 commit comments

Comments
 (0)