
Commit 98e7f22

enable skipping of SW attention layers when using FP8 KV cache (vllm-project#33695)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
1 parent: b111f8a

File tree

4 files changed: +58 −0 lines changed

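Taken together, the change lets a caller keep selected layers on the unquantized ("auto") KV cache dtype while the rest of the model uses FP8. A minimal offline-inference sketch, assuming the `LLM` entry point forwards the new keyword to `EngineArgs` the same way the `vllm_runner` test fixture does in the test below:

```python
# Sketch: FP8 KV cache everywhere except layers 0 and 2.
# Assumes LLM(**kwargs) forwards unrecognized kwargs to EngineArgs,
# as vllm_runner does in tests/quantization/test_fp8.py below.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    kv_cache_dtype="fp8",
    kv_cache_dtype_skip_layers=["0", "2"],
    enforce_eager=True,
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```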

tests/quantization/test_fp8.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -466,3 +466,26 @@ def test_fp8_reloading(
         weight_loader(param, torch.zeros(shape))  # cannot use empty

     method.process_weights_after_loading(layer)
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("fp8"),
+    reason="FP8 is not supported on this GPU type.",
+)
+def test_kv_cache_dtype_skip_layers(vllm_runner, monkeypatch):
+    """Test that kv_cache_dtype_skip_layers skips quantization for specified layers."""
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    with vllm_runner(
+        "facebook/opt-125m",
+        kv_cache_dtype="fp8",
+        kv_cache_dtype_skip_layers=["0", "2"],
+        enforce_eager=True,
+    ) as llm:
+
+        def check_layers(model):
+            for i, layer in enumerate(model.model.decoder.layers):
+                expected = "auto" if str(i) in ["0", "2"] else "fp8"
+                assert layer.self_attn.attn.kv_cache_dtype == expected
+
+        llm.apply_model(check_layers)
```

vllm/config/cache.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -87,6 +87,9 @@ class CacheConfig:
     It enables dynamic calculation of `k_scale` and `v_scale` when
     kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
     checkpoint if available. Otherwise, the scales will default to 1.0."""
+    kv_cache_dtype_skip_layers: list[str] = field(default_factory=list)
+    """Layer patterns to skip KV cache quantization. Accepts layer indices
+    (e.g., '0', '2', '4') or attention type names (e.g., 'sliding_window')."""
     cpu_kvcache_space_bytes: int | None = None
     """(CPU backend only) CPU key-value cache space."""
     mamba_page_size_padded: int | None = None
```
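
The two pattern kinds can be mixed in one list. A small sketch of the intended semantics, assuming `CacheConfig`'s remaining fields keep their defaults and that the FP8 selection lives on its `cache_dtype` field:

```python
from vllm.config.cache import CacheConfig

# Layers 0 and 2 are skipped by index; every sliding-window layer is
# skipped by attention type, regardless of its index.
cfg = CacheConfig(
    cache_dtype="fp8",
    kv_cache_dtype_skip_layers=["0", "2", "sliding_window"],
)
```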

vllm/engine/arg_utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -597,6 +597,9 @@ class EngineArgs:
     attention_backend: AttentionBackendEnum | None = AttentionConfig.backend

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
+    kv_cache_dtype_skip_layers: list[str] = get_field(
+        CacheConfig, "kv_cache_dtype_skip_layers"
+    )
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
     mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
     mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
@@ -1003,6 +1006,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         cache_group.add_argument(
             "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
         )
+        cache_group.add_argument(
+            "--kv-cache-dtype-skip-layers", **cache_kwargs["kv_cache_dtype_skip_layers"]
+        )
         cache_group.add_argument(
             "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]
         )
@@ -1578,6 +1584,7 @@ def create_engine_config(
             enable_prefix_caching=self.enable_prefix_caching,
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             calculate_kv_scales=self.calculate_kv_scales,
+            kv_cache_dtype_skip_layers=self.kv_cache_dtype_skip_layers,
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
             mamba_cache_dtype=self.mamba_cache_dtype,
             mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
```
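
The same plumbing is reachable programmatically. A sketch, assuming `create_engine_config()` can be called with its default arguments here:

```python
from vllm.engine.arg_utils import EngineArgs

# The new engine arg flows through create_engine_config into CacheConfig.
args = EngineArgs(
    model="facebook/opt-125m",
    kv_cache_dtype="fp8",
    kv_cache_dtype_skip_layers=["sliding_window"],
)
vllm_config = args.create_engine_config()
assert vllm_config.cache_config.kv_cache_dtype_skip_layers == ["sliding_window"]
```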

vllm/model_executor/layers/attention/attention.py

Lines changed: 25 additions & 0 deletions

```diff
@@ -240,6 +240,31 @@ def __init__(
             and kv_cache_scheme.get("strategy") == "attn_head"
         )

+        # Skip quantization for specified layers
+        if cache_config is not None and cache_config.kv_cache_dtype_skip_layers:
+            from vllm.model_executor.models.utils import extract_layer_index
+
+            skip = False
+            # Check attention type
+            if (
+                sliding_window is not None
+                and "sliding_window" in cache_config.kv_cache_dtype_skip_layers
+            ):
+                skip = True
+            # Check layer index
+            layer_idx = extract_layer_index(prefix)
+            if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
+                skip = True
+            if skip:
+                kv_cache_dtype = "auto"
+                calculate_kv_scales = False
+                logger.info(
+                    "Layer %s: kv_cache_dtype=%s, sliding_window=%s",
+                    prefix,
+                    kv_cache_dtype,
+                    sliding_window,
+                )
+
         self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
             kv_cache_dtype, vllm_config.model_config
         )
```
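
Distilled out of the constructor, the decision reduces to a small predicate. A standalone sketch of the logic above, where the prefix parsing is a simplification of vLLM's `extract_layer_index` helper:

```python
def should_skip_kv_quant(
    prefix: str,
    sliding_window: int | None,
    skip_patterns: list[str],
) -> bool:
    """Return True if this layer should stay on kv_cache_dtype='auto'."""
    # The pattern "sliding_window" matches any layer that has a sliding
    # window, regardless of its index.
    if sliding_window is not None and "sliding_window" in skip_patterns:
        return True
    # Numeric patterns match the layer index parsed from the prefix,
    # e.g. "model.decoder.layers.2.self_attn" -> "2" (simplified parse;
    # the real helper scans the dotted prefix for an integer component).
    layer_idx = prefix.split(".layers.")[1].split(".")[0]
    return layer_idx in skip_patterns


# e.g. should_skip_kv_quant("model.decoder.layers.2.self_attn", None, ["0", "2"])
# -> True
```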
