6 changes: 6 additions & 0 deletions auto_round_extension/vllm_ext/auto_round_ext.py
@@ -33,6 +33,12 @@ class AutoRoundExtensionConfig(_BaseAutoRoundConfig):
 
     def get_quant_method(self, layer: torch.nn.Module, prefix: str):
         # FIXME: (yi) make it compatible with `AutoRoundConfig`
+        from vllm.attention.layer import Attention
+
+        if isinstance(layer, Attention):
+            from auto_round_extension.vllm_ext.kv_cache import AutoRoundKVCacheMethod
+
+            return AutoRoundKVCacheMethod(self)
         if isinstance(layer, FusedMoE):
             quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
             return quant_method
38 changes: 38 additions & 0 deletions auto_round_extension/vllm_ext/kv_cache.py
@@ -0,0 +1,38 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

logger = init_logger(__name__)


class AutoRoundKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""

def __init__(self, quant_config):
self.validate_kv_cache_scheme(quant_config)
super().__init__(quant_config)

@staticmethod
def validate_kv_cache_scheme(quant_config):
# FIXME: parse from quant_config
return True
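
Note: `validate_kv_cache_scheme` is still a stub (see the FIXME), so any checkpoint is currently accepted. Below is a minimal sketch of what a stricter check could look like, assuming, purely hypothetically, that the quant config exposes a `kv_cache_scheme` dict with `dtype` and `granularity` fields; these names are illustrative, not the extension's real schema.

```python
# Hypothetical sketch only: the attribute and field names below are assumptions,
# not the extension's real config schema.
def validate_kv_cache_scheme(quant_config) -> bool:
    scheme = getattr(quant_config, "kv_cache_scheme", None)  # assumed attribute
    if scheme is None:
        return True  # no explicit KV-cache section: keep today's permissive behavior
    if scheme.get("dtype") not in ("fp8", "fp8_e4m3"):
        raise NotImplementedError(f"Unsupported KV-cache dtype: {scheme.get('dtype')!r}")
    if scheme.get("granularity", "per_tensor") != "per_tensor":
        raise NotImplementedError("Only per-tensor KV-cache scales are supported.")
    return True
```

Raising early here would keep misconfigured checkpoints from failing later inside the attention backend.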
25 changes: 10 additions & 15 deletions auto_round_extension/vllm_ext/linear_impl_mxfp4.py
@@ -86,8 +86,7 @@ def create_weights(
 
     def process_weights_after_loading(self, layer) -> None:
         # FIXME: may dequant to bf16
-        if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
-
+        if envs.VLLM_MXFP4_PRE_UNPACK_TO_FP8:
             weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8(
                 data_lp=layer.weight_packed,
                 scale_e8m0=layer.weight_scale,
@@ -110,20 +109,16 @@ def process_weights_after_loading(self, layer) -> None:
                 requires_grad=False,
             ),
         )
+        else:
+            raise NotImplementedError("Only VLLM_MXFP4_PRE_UNPACK_TO_FP8 is supported now.")
 
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        if not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
-            out = run_mxfp4_emulations(x=x, weight=layer.weight_packed, weight_scale=layer.weight_scale)
-            if bias is not None:
-                out = out + bias
-            return out
-        else:
-            out = mxfp4_gemm_with_unpacked_weight(
-                x=x,
-                weight_fp8=layer.weight_unpacked_fp8,
-                weight_scale_bf16=layer.weight_scale_bf16,
-                bias=bias,
-            )
-            return out
+        out = mxfp4_gemm_with_unpacked_weight(
+            x=x,
+            weight_fp8=layer.weight_unpacked_fp8,
+            weight_scale_bf16=layer.weight_scale_bf16,
+            bias=bias,
+        )
+        return out
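
For context on what `dequant_mxfp4_to_fp8` is undoing: in the MX format each byte of `weight_packed` carries two FP4 (E2M1) codes, and every block of 32 consecutive elements shares one power-of-two (E8M0) scale stored in `weight_scale`. Below is a slow, pure-PyTorch reference of that layout; the low-nibble-first packing order and the 32-element block size are assumed from the OCP MX convention rather than taken from this extension's kernels.

```python
import torch

# E2M1 magnitudes indexed by the low 3 bits of each FP4 code; bit 3 is the sign.
_FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_mxfp4_reference(packed: torch.Tensor, scale_e8m0: torch.Tensor, block: int = 32) -> torch.Tensor:
    """Unpack uint8-packed MXFP4 values to bf16 (reference only, not fast).

    packed:     [..., K // 2] uint8, two FP4 codes per byte (low nibble first -- assumed).
    scale_e8m0: [..., K // block] uint8 biased exponents, scale = 2**(e - 127).
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack((lo, hi), dim=-1).flatten(-2)            # [..., K], interleaved
    sign = torch.where((codes & 0x8) != 0, -1.0, 1.0)
    magnitude = _FP4_VALUES.to(packed.device)[(codes & 0x7).long()]
    values = sign * magnitude                                    # element values in fp32
    scale = torch.exp2(scale_e8m0.to(torch.float32) - 127.0)     # E8M0 -> power of two
    values = values.reshape(*values.shape[:-1], -1, block) * scale.unsqueeze(-1)
    return values.reshape(*values.shape[:-2], -1).to(torch.bfloat16)
```

The production path instead converts straight to FP8 weights plus BF16 scales so the GEMM can run on FP8 hardware; the reference above is only meant to make the packing explicit.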
61 changes: 61 additions & 0 deletions auto_round_extension/vllm_ext/tests/test_fp8kv.py
@@ -0,0 +1,61 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from vllm.platforms import current_platform


def cuda_capability_at_least(major, minor):
device_capability = torch.cuda.get_device_capability()
    return device_capability[0] > major or (device_capability[0] == major and device_capability[1] >= minor)


MODELS = ["/home/yiliu7/workspace/auto-round/examples/Qwen2.5-0.5B-Instruct-ar-MXFP4-fp8"]


@pytest.fixture(autouse=True)
def set_vllm_ar_env(monkeypatch):
Review comment (Contributor Author): Hi @mengniwang95 @xin3he, please be aware of the usage of FP8 KV.

monkeypatch.setenv("VLLM_AR_MXFP4_MODULAR_MOE", "1")
monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_TO_FP8", "1")
monkeypatch.setenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0")
monkeypatch.setenv("VLLM_ENABLE_STATIC_MOE", "0")
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
monkeypatch.setenv("VLLM_ENABLE_AR_EXT", "1")
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "1")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")


@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="only supports CUDA backend.",
)
@pytest.mark.skipif(
not cuda_capability_at_least(10, 0), reason="FP8 KV cache only supported on CUDA with compute capability >= 10.0"
)
@pytest.mark.parametrize("model", MODELS)
def test_auto_fp8_kv(vllm_runner, model):
with vllm_runner(
model,
# enforce_eager=True,
kv_cache_dtype="fp8",
gpu_memory_utilization=0.1,
) as llm:
output = llm.generate_greedy(["The capital of France is"], max_tokens=8)
assert (
llm.llm.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.kv_cache_dtype
== torch.uint8
), f"Expected kv_cache_dtype to be torch.uint8, but got {llm.llm.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.kv_cache_dtype}"
assert output
print(f"output is: {output[0][1]}")