
Commit 6bde0e1

Add static FP8 attention support (#1061)
Signed-off-by: yiliu30 <[email protected]>
1 parent ede64fa commit 6bde0e1

8 files changed: +316 additions, −67 deletions

auto_round/__main__.py

Lines changed: 16 additions & 0 deletions
@@ -295,7 +295,21 @@ def __init__(self, *args, **kwargs):
            help="List of layer names to keep in original precision (not quantized). "
            "Useful for preserving critical layers. Separate multiple names with commas.",
        )
+        scheme.add_argument(
+            "--static_kv_dtype",
+            default=None,
+            type=str,
+            choices=["fp8", "float8_e4m3fn"],
+            help="Data type for static quantization of key and value.",
+        )

+        scheme.add_argument(
+            "--static_attention_dtype",
+            default=None,
+            type=str,
+            choices=["fp8", "float8_e4m3fn"],
+            help="Data type for static quantization of attention.",
+        )
        gguf = self.add_argument_group("Double Quant Arguments")
        gguf.add_argument(
            "--super_group_size", default=None, type=int, help="Super group size for double quantization."

@@ -556,6 +570,8 @@ def tune(args):
        super_group_size=args.super_group_size,
        quant_lm_head=args.quant_lm_head,
        fp_layers=args.fp_layers,
+        static_kv_dtype=args.static_kv_dtype,
+        static_attention_dtype=args.static_attention_dtype,
    )
    mllm_config = MLLMExtraConfig(
        quant_nontext_module=args.quant_nontext_module, extra_data_dir=args.extra_data_dir, template=args.template
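With these flags in place, static FP8 calibration of the KV cache or of attention can be requested straight from the CLI, e.g. `python -m auto_round --model <model_name> --static_attention_dtype fp8 ...`; the `--model` flag and the rest of the invocation are assumed from the existing CLI and are not part of this diff.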

auto_round/compressors/base.py

Lines changed: 12 additions & 1 deletion
@@ -237,6 +237,7 @@ def __init__(
        enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
        self.momentum = kwargs.pop("momentum", 0.0)
        static_kv_dtype = kwargs.pop("static_kv_dtype", None)
+        static_attention_dtype = kwargs.pop("static_attention_dtype", None)
        model_dtype = kwargs.pop("model_dtype", None)
        device = kwargs.pop("device", None)
        if envs.AR_USE_MODELSCOPE:

@@ -356,6 +357,11 @@ def __init__(
        if self.static_kv_dtype is not None:
            logger.warning("The static kv is experimental and currently has limited support.")

+        # Attention static dtype
+        self.static_attention_dtype = static_attention_dtype
+        if self.static_attention_dtype is not None:
+            logger.warning("The static attention dtype is experimental and currently has limited support.")
+
        self._set_amp_dtype()
        self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
        if self.act_bits <= 8 and self.amp_dtype == torch.float16:

@@ -1004,7 +1010,12 @@ def quantize_and_save(
        kwargs.pop("inplace", None)

        # Perform model quantization
-        if self.static_kv_dtype is not None:
+        if self.static_attention_dtype is not None:
+            from auto_round.experimental.attention import attention_quant_ctx
+
+            with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
+                model, _ = self.quantize()
+        elif self.static_kv_dtype is not None:
            from auto_round.experimental.kv_cache import kvcache_quant_context

            with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
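For context, a minimal end-to-end sketch of how the new keyword is expected to reach quantize_and_save; the AutoRound entry point, its positional arguments, and the output path below are assumptions based on the library's public API rather than part of this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound  # assumed public entry point

model_name = "facebook/opt-125m"  # placeholder calibration model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# "fp8" matches the CLI choices; quantize_and_save() then wraps quantization
# in attention_quant_ctx() as shown in the hunk above.
autoround = AutoRound(model, tokenizer, static_attention_dtype="fp8")
autoround.quantize_and_save("./tmp_autoround")  # output path is illustrative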

auto_round/compressors/config.py

Lines changed: 1 addition & 0 deletions
@@ -275,6 +275,7 @@ class SchemeExtraConfig(BaseExtraConfig):
    super_bits: int = None
    super_group_size: int = None
    static_kv_dtype: Union[str, torch.dtype] = None
+    static_attention_dtype: Union[str, torch.dtype] = None
    quant_lm_head: bool = False
    fp_layers: str = None

auto_round/experimental/attention.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design is adapted from:
# https://github.com/vllm-project/compressed-tensors/pull/491


import contextlib
import inspect
from functools import partial
from typing import Callable, Optional
from weakref import ref

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from auto_round.experimental.kv_cache import kvcache_quant_context
from auto_round.experimental.utils import (
    is_attention_module,
    per_tensor_fp8_qdq,
    update_parameter_data,
)
from auto_round.utils import logger

__all__ = [
    "QuantizedAttentionImpl",
    "init_hooked_attention",
    "attention_quant_ctx",
]


ATTN_IMPL_ATTR_NAME = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"
QUERY_SCALE_NAME = "q_scale"
QUERY_MAX_NAME = "q_max"


class QuantizedAttentionImpl(torch.nn.Module):
    """
    QuantizedAttentionImpl module which wraps the functionality of the original
    attention implementation. Unlike the original attention function, this
    implementation is a `torch.nn.Module` which can be hooked to trigger
    transforms and calibration hooks.

    This module works by being registered as a submodule of attention modules via
    `init_hooked_attention`, registering a new attention implementation function
    which calls this module, then setting the model attention implementation to the new
    function. After triggering hooks and quantization, this module calls the original
    attention implementation function.

    :param config: model config
    :param attn_module: parent attention module
    """

    _original_impl = "sdpa"

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular references
        # register query max and query scale
        device = next(attn_module.parameters()).device
        initial_max = torch.tensor([float("-inf")], device=device)
        update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME)
        initial_scale = torch.tensor([0.0], device=device)
        update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME)

    def forward(
        self,
        module: Module,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args,
        **kwargs,
    ):
        cur_query_max = query.abs().max()
        query_max = torch.max(
            getattr(module, QUERY_MAX_NAME).data,
            cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device),
        )
        update_parameter_data(module, query_max, QUERY_MAX_NAME)
        query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max)
        update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME)
        # original attention
        return ALL_ATTENTION_FUNCTIONS[self._original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
    if hasattr(module, ATTN_IMPL_ATTR_NAME):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)  # pylint: disable=E0601


def init_hooked_attention(module: Module, config):
    """
    Initialize a `QuantizedAttentionImpl` instance attached to the attention module
    and register the hooked attention implementation on the model config.

    :param module: attention module to initialize with
    :param config: model config whose attention implementation is patched
    """
    if not hasattr(module, ATTN_IMPL_ATTR_NAME):
        module.register_module(ATTN_IMPL_ATTR_NAME, QuantizedAttentionImpl(config, module))
    if config._attn_implementation != HOOKED_ATTENTION_NAME:
        # assumes only one model at a time
        global _original_impl
        _original_impl = config._attn_implementation
        # Add the new implementation to the AttentionInterface mapping
        AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
        config._attn_implementation = HOOKED_ATTENTION_NAME

    # initialize_hooked_kv_cache(model, module)


def prep_attention_module_for_calibration(module: torch.nn.Module, config):
    if is_attention_module(module):
        logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration")
        init_hooked_attention(module, config)


def clean_up_hooked_attention(module, model):
    if is_attention_module(module):
        # Cleanup phase: Restore the original attention implementation
        if hasattr(model.config, "_attn_implementation") and hasattr(model, "_original_impl"):
            model.config._attn_implementation = model._original_impl
            del model._original_impl


@contextlib.contextmanager
def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn):
    try:
        # Setup phase: Initialize hooked attention
        prepare_fn = partial(prep_attention_module_for_calibration, config=model.config)
        model.apply(prepare_fn)
        with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype):
            yield model
    finally:
        clean_fn = partial(clean_up_hooked_attention, model=model)
        model.apply(clean_fn)
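A short sketch of how the context manager is intended to be driven during calibration; the loop below is illustrative (the real calibration path lives in the compressor's quantize()), and the dataloader is a stand-in:

import torch

from auto_round.experimental.attention import attention_quant_ctx


def calibrate_attention_scales(model, dataloader):
    model.eval()
    with attention_quant_ctx(model, static_attention_dtype=torch.float8_e4m3fn):
        with torch.no_grad():
            for batch in dataloader:
                # Hooked attention observes the running query max and refreshes q_scale;
                # the nested kvcache_quant_context handles the key/value scales.
                model(**batch)
    # On exit the original attention implementation is restored, while the calibrated
    # per-tensor scales stay registered as parameters on each attention module.
    for name, module in model.named_modules():
        if hasattr(module, "q_scale"):
            print(name, module.q_scale)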

auto_round/experimental/kv_cache.py

Lines changed: 11 additions & 62 deletions
@@ -24,6 +24,12 @@
import torch
from transformers.cache_utils import DynamicCache

+from auto_round.experimental.utils import (
+    is_attention_module,
+    normalize_static_kv_dtype,
+    per_tensor_fp8_qdq,
+    update_parameter_data,
+)
from auto_round.utils import logger

__all__ = [

@@ -81,13 +87,6 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
    return lst


-def fp8_per_tensor_qdq(tensor):
-    from auto_round.data_type.fp8 import quant_fp8_sym
-
-    qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
-    return qdq_tensor, scale
-
-
class QuantizedKVParameterCache(DynamicCache):
    """
    Quantized KV cache used in the forward call based on HF's dynamic cache.

@@ -173,8 +172,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
            assert kv_type == KVCacheScaleType.VALUE
            scales = self.v_scales

-        qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
+        qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
        return qdq_tensor

@@ -192,13 +191,9 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
    quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
    setattr(module, "kv_cache", quantized_kv_cache)
    logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
-
-
-def is_attention_module(module: torch.nn.Module):
-    # FIXME: Handle this better.
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
-    )
+    init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
+    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
+    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)


def calibrate_kv_cache_input_hook(

@@ -209,7 +204,6 @@ def calibrate_kv_cache_input_hook(
    kv_cache quantization. Will update the passed in
    kv_cache to singleton QuantizedKVParameterCache.
    """
-    logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
    kv_cache = getattr(module, "kv_cache")
    # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`.
    # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280

@@ -221,33 +215,10 @@ def calibrate_kv_cache_input_hook(
    return args, kwargs


-def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
-    """
-    Update the data of a parameter in a module.
-    If the parameter does not exist, it will be created.
-    """
-    if hasattr(module, name):
-        param = getattr(module, name)
-        if isinstance(param, torch.nn.Parameter):
-            param.data = new_val
-        else:
-            module.register_parameter(name, torch.nn.Parameter(new_val))
-    else:
-        logger.warning(
-            "Parameter %s not found in module %s, creating new parameter."
-            % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
-        )
-        module.register_parameter(name, torch.nn.Parameter(new_val))
-
-
def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
-    logger.debug(
-        "Calibrate kv_cache output hook for %s %s"
-        % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
-    )
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]

@@ -261,28 +232,6 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
        module.register_forward_hook(calibrate_kv_cache_output_hook)


-def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
-    valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
-    valid_torch_dtype = {
-        "float16": torch.float16,
-        "bfloat16": torch.bfloat16,
-        "fp8": torch.float8_e4m3fn,
-        "float8_e4m3fn": torch.float8_e4m3fn,
-        "float32": torch.float32,
-        "float": torch.float32,  # Alias for float32
-    }
-    if static_kv_dtype in valid_dtype_name_lst:
-        new_dtype = valid_torch_dtype[static_kv_dtype]
-    elif static_kv_dtype in valid_torch_dtype.values():
-        new_dtype = static_kv_dtype
-    else:
-        raise ValueError(
-            f"Invalid static kv dtype: {static_kv_dtype}. "
-            f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}."
-        )
-    return new_dtype
-
-
@contextlib.contextmanager
def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
    """Context manager for FP8 KV cache quantization operations."""

auto_round/experimental/qmodules/fp8_static.py

Lines changed: 2 additions & 2 deletions
@@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor):

    @torch.no_grad()
    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
-
+        original_dtype = bf16_input.dtype
        qdq_input = self.qdq_input(bf16_input)
        qdq_weight = self.dequant_weight_online()
-        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
+        out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias)
        return out
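The cast matters because torch.nn.functional.linear requires matching dtypes: if the quantize-dequantize helpers hand back, say, float32 while the incoming activations are bfloat16, the matmul would otherwise raise (or drift to an unintended output dtype). A toy illustration of the failure mode the fix guards against; shapes and dtypes are arbitrary:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)  # activation in the model's dtype
w = torch.randn(4, 8, dtype=torch.float32)   # e.g. a dequantized weight in a wider dtype

try:
    F.linear(x, w)  # dtype mismatch raises a RuntimeError
except RuntimeError as err:
    print(f"mismatch: {err}")

out = F.linear(x, w.to(x.dtype))  # cast back to the original dtype, as the fix does
print(out.dtype)  # torch.bfloat16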
