diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py
index d665fd568..8ef85069a 100644
--- a/auto_round_extension/vllm_ext/auto_round_ext.py
+++ b/auto_round_extension/vllm_ext/auto_round_ext.py
@@ -33,6 +33,12 @@ class AutoRoundExtensionConfig(_BaseAutoRoundConfig):
 
     def get_quant_method(self, layer: torch.nn.Module, prefix: str):
         # FIXME: (yi) make it compatible with `AutoRoundConfig`
+        from vllm.attention.layer import Attention
+
+        if isinstance(layer, Attention):
+            from auto_round_extension.vllm_ext.kv_cache import AutoRoundKVCacheMethod
+
+            return AutoRoundKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
             quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
             return quant_method
diff --git a/auto_round_extension/vllm_ext/kv_cache.py b/auto_round_extension/vllm_ext/kv_cache.py
new file mode 100644
index 000000000..ec2b7e179
--- /dev/null
+++ b/auto_round_extension/vllm_ext/kv_cache.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
+
+import torch
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+
+logger = init_logger(__name__)
+
+
+class AutoRoundKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from auto-round
+    checkpoints.
+ """ + + def __init__(self, quant_config): + self.validate_kv_cache_scheme(quant_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(quant_config): + # FIXME: parse from quant_config + return True diff --git a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py index 04d5e20f8..e544fdbbf 100644 --- a/auto_round_extension/vllm_ext/linear_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/linear_impl_mxfp4.py @@ -86,8 +86,7 @@ def create_weights( def process_weights_after_loading(self, layer) -> None: # FIXME: may dequant to bf16 - if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - + if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8: weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( data_lp=layer.weight_packed, scale_e8m0=layer.weight_scale, @@ -110,20 +109,16 @@ def process_weights_after_loading(self, layer) -> None: requires_grad=False, ), ) + else: + raise NotImplementedError("Only VLLM_MXFP4_PRE_UNPACK_TO_FP8 is supported now.") def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: - if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale) - if bias is not None: - out = out + bias - return out - else: - out = mxfp4_gemm_with_unpacked_weight( - x=x, - weight_fp8=layer.weight_unpacked_fp8, - weight_scale_bf16=layer.weight_scale_bf16, - bias=bias, - ) - return out + out = mxfp4_gemm_with_unpacked_weight( + x=x, + weight_fp8=layer.weight_unpacked_fp8, + weight_scale_bf16=layer.weight_scale_bf16, + bias=bias, + ) + return out diff --git a/auto_round_extension/vllm_ext/tests/test_fp8kv.py b/auto_round_extension/vllm_ext/tests/test_fp8kv.py new file mode 100644 index 000000000..709308f9a --- /dev/null +++ b/auto_round_extension/vllm_ext/tests/test_fp8kv.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest
+import torch
+from vllm.platforms import current_platform
+
+
+def cuda_capability_at_least(major, minor):
+    device_capability = torch.cuda.get_device_capability()
+    return device_capability[0] > major or (device_capability[0] == major and device_capability[1] >= minor)
+
+
+MODELS = ["/home/yiliu7/workspace/auto-round/examples/Qwen2.5-0.5B-Instruct-ar-MXFP4-fp8"]
+
+
+@pytest.fixture(autouse=True)
+def set_vllm_ar_env(monkeypatch):
+    monkeypatch.setenv("VLLM_AR_MXFP4_MODULAR_MOE", "1")
+    monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "1")
+    monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0")
+    monkeypatch.setenv("VLLM_ENABLE_STATIC_MOE", "0")
+    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
+    monkeypatch.setenv("VLLM_ENABLE_AR_EXT", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    monkeypatch.setenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "1")
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="only supports CUDA backend.",
+)
+@pytest.mark.skipif(
+    not cuda_capability_at_least(10, 0), reason="FP8 KV cache only supported on CUDA with compute capability >= 10.0"
+)
+@pytest.mark.parametrize("model", MODELS)
+def test_auto_fp8_kv(vllm_runner, model):
+    with vllm_runner(
+        model,
+        # enforce_eager=True,
+        kv_cache_dtype="fp8",
+        gpu_memory_utilization=0.1,
+    ) as llm:
+        output = llm.generate_greedy(["The capital of France is"], max_tokens=8)
+        kv_cache_dtype = (
+            llm.llm.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.kv_cache_dtype
+        )
+        assert kv_cache_dtype == torch.uint8, f"Expected kv_cache_dtype to be torch.uint8, but got {kv_cache_dtype}"
+        assert output
+        print(f"output is: {output[0][1]}")
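
Not part of the diff: a minimal, hedged sketch of how the new FP8 KV-cache path could be exercised outside pytest. It mirrors the environment toggles and engine arguments used in test_auto_fp8_kv; MODEL_PATH is a placeholder for a local MXFP4 checkpoint (not a published model), and the snippet assumes the auto-round vLLM extension is installed and activated via VLLM_ENABLE_AR_EXT.

import os

# Mirror the key env toggles from the test fixture; set them before vLLM is imported.
os.environ["VLLM_ENABLE_AR_EXT"] = "1"
os.environ["VLLM_AR_MXFP4_MODULAR_MOE"] = "1"
os.environ["VLLM_MXFP4_PRE_UNPACK_TO_FP8"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

MODEL_PATH = "/path/to/Qwen2.5-0.5B-Instruct-ar-MXFP4-fp8"  # placeholder, point this at a local checkpoint

# kv_cache_dtype="fp8" requests an FP8 KV cache, matching the vllm_runner arguments in the test.
llm = LLM(model=MODEL_PATH, kv_cache_dtype="fp8", gpu_memory_utilization=0.1)
outputs = llm.generate(["The capital of France is"], SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)

Greedy decoding (temperature=0.0) keeps the output comparable to the generate_greedy call in the test.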