
Commit c5b1c41

Add static FP8 attention support (#1045)
Signed-off-by: yiliu30 <[email protected]>
1 parent e1b89d2 commit c5b1c41

7 files changed: +338 -68 lines changed

auto_round/__main__.py

Lines changed: 9 additions & 0 deletions
@@ -153,6 +153,13 @@ def __init__(self, *args, **kwargs):
         basic.add_argument(
             "--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
         )
+        basic.add_argument(
+            "--static_kv_dtype", default=None, type=str, help="Data type for static quantization of key and value. "
+        )
+
+        basic.add_argument(
+            "--static_attention_dtype", default=None, type=str, help="Data type for static quantization of attention. "
+        )
 
         tuning = self.add_argument_group("Tuning Arguments")
         tuning.add_argument(
@@ -599,6 +606,8 @@ def tune(args):
         layer_config=layer_config,
         model_dtype=args.model_dtype,
         momentum=args.momentum,
+        static_kv_dtype=args.static_kv_dtype,
+        static_attention_dtype=args.static_attention_dtype,
     )
 
     model_name = args.model.rstrip("/")
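Both new options are plain strings defaulting to None, so a value such as "fp8" reaches tune(args) untouched and is only mapped to a torch dtype later by normalize_static_kv_dtype (moved to auto_round.experimental.utils in this commit). A minimal argparse sketch, separate from the auto_round parser, showing how the flags surface:

    # Sketch only (not auto_round code): how the new string flags land on the namespace.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--static_kv_dtype", default=None, type=str)
    parser.add_argument("--static_attention_dtype", default=None, type=str)

    args = parser.parse_args(["--static_attention_dtype", "fp8"])
    print(args.static_kv_dtype, args.static_attention_dtype)  # prints: None fp8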

auto_round/compressors/base.py

Lines changed: 12 additions & 1 deletion
@@ -237,6 +237,7 @@ def __init__(
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         self.momentum = kwargs.pop("momentum", 0.0)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
+        static_attention_dtype = kwargs.pop("static_attention_dtype", None)
         model_dtype = kwargs.pop("model_dtype", None)
         device = kwargs.pop("device", None)
         if envs.AR_USE_MODELSCOPE:
@@ -356,6 +357,11 @@ def __init__(
         if self.static_kv_dtype is not None:
             logger.warning("The static kv is experimental and currently has limited support.")
 
+        # Attention static dtype
+        self.static_attention_dtype = static_attention_dtype
+        if self.static_attention_dtype is not None:
+            logger.warning("The static attention dtype is experimental and currently has limited support.")
+
         self._set_amp_dtype()
         self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
         if self.act_bits <= 8 and self.amp_dtype == torch.float16:
@@ -1004,7 +1010,12 @@ def quantize_and_save(
         kwargs.pop("inplace", None)
 
         # Perform model quantization
-        if self.static_kv_dtype is not None:
+        if self.static_attention_dtype is not None:
+            from auto_round.experimental.attention import attention_quant_ctx
+
+            with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
+                model, _ = self.quantize()
+        elif self.static_kv_dtype is not None:
             from auto_round.experimental.kv_cache import kvcache_quant_context
 
             with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
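A hedged sketch of driving this from the Python API rather than the CLI, assuming the public AutoRound entry point forwards extra kwargs such as static_attention_dtype down to this compressor; the model name and output directory are placeholders:

    # Sketch only: assumes AutoRound passes static_attention_dtype through to the compressor.
    from auto_round import AutoRound

    ar = AutoRound(
        "facebook/opt-125m",             # placeholder model name or loaded module
        static_attention_dtype="fp8",    # routes quantize_and_save through attention_quant_ctx
    )
    ar.quantize_and_save(output_dir="tmp_autoround")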
auto_round/experimental/attention.py (new file)

Lines changed: 190 additions & 0 deletions

# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design adapted from:
# https://github.com/vllm-project/compressed-tensors/pull/491


import contextlib
import inspect
from functools import partial
from typing import Callable, Optional
from weakref import ref

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from auto_round.experimental.kv_cache import kvcache_quant_context
from auto_round.experimental.utils import (
    is_attention_module,
    per_tensor_fp8_qdq,
    update_parameter_data,
)
from auto_round.utils import logger

__all__ = [
    "QuantizedAttentionImpl",
    "initialize_hooked_attention",
    "IMPL_ATTR",
    "attention_quant_ctx",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"
QUERY_SCALE_NAME = "q_scale"
QUERY_MAX_NAME = "q_max"


class QuantizedAttentionImpl(torch.nn.Module):
    """
    QuantizedAttentionImpl module which wraps the functionality of the original
    attention implementation. Unlike the original attention function, this
    implementation is a `torch.nn.Module` which can be hooked to trigger
    transforms and calibration hooks.

    This module works by being registered as a submodule of attention modules via
    `initialize_hooked_attention`, registering a new attention implementation function
    which calls this module, then setting the model attention implementation to the new
    function. After triggering hooks and quantization, this module calls the original
    attention implementation function.

    :param attn_module: parent attention module
    """

    _original_impl = "sdpa"

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular references
        # register query max
        device = next(attn_module.parameters()).device
        initial_max = torch.tensor([float("-inf")], device=device)
        update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME)
        initial_scale = torch.tensor([0.0], device=device)
        update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME)

    def forward(
        self,
        module: Module,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args,
        **kwargs,
    ):
        cur_query_max = query.abs().max()
        query_max = torch.max(
            getattr(module, QUERY_MAX_NAME).data,
            cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device),
        )
        update_parameter_data(module, query_max, QUERY_MAX_NAME)
        query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max)
        update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME)
        # original attention
        return ALL_ATTENTION_FUNCTIONS[self._original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)  # pylint: disable=E0601


def initialize_hooked_attention(module: Module, config):
    """
    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
    attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(config, module))
    if config._attn_implementation != HOOKED_ATTENTION_NAME:
        # assumes only one model at a time
        global _original_impl
        _original_impl = config._attn_implementation
        # Add new implementation to AttentionInterface(mapping)
        AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
        config._attn_implementation = HOOKED_ATTENTION_NAME

    # initialize_hooked_kv_cache(model, module)


def prep_attention_module_for_calibration(module: torch.nn.Module, config):
    if is_attention_module(module):
        logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration")
        initialize_hooked_attention(module, config)


# # ----- hooks ----- #


# def register_query_hook(module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]) -> RemovableHandle:
#     """
#     Register a hook which takes post-rope query states as an argument and
#     returns the modified query states or `None`
#
#     :param module: attention module to add hook to
#     :param hook: query hook function
#     """
#     impl = getattr(module, IMPL_ATTR)
#
#     def _hook(impl: QuantizedAttentionImpl, args, kwargs):
#         bound = inspect.signature(impl.forward).bind(*args, **kwargs)
#         value = hook(module, bound.arguments["query"])
#         if value is not None:
#             bound.arguments["query"] = value
#
#         return bound.args, bound.kwargs
#
#     return impl.register_forward_pre_hook(_hook, with_kwargs=True)


def clean_up_hooked_attention(module, model):
    if is_attention_module(module):
        # Cleanup phase: Restore the original attention implementation
        if hasattr(model.config, "_attn_implementation") and hasattr(model, "_original_impl"):
            model.config._attn_implementation = model._original_impl
            del model._original_impl


@contextlib.contextmanager
def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn):
    try:
        # Setup phase: Initialize hooked attention
        prepare_fn = partial(prep_attention_module_for_calibration, config=model.config)
        model.apply(prepare_fn)
        with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype):
            yield model
    finally:
        clean_fn = partial(clean_up_hooked_attention, model=model)
        model.apply(clean_fn)
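per_tensor_fp8_qdq is imported from auto_round.experimental.utils and is not part of this diff; judging from the fp8_per_tensor_qdq wrapper removed from kv_cache.py below, it performs a symmetric per-tensor FP8 (e4m3) quantize-dequantize and returns the dequantized tensor together with its scale. A rough standalone approximation, for illustration only (not the auto_round implementation):

    from typing import Optional

    import torch

    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


    def per_tensor_fp8_qdq_sketch(tensor: torch.Tensor, tensor_max: Optional[torch.Tensor] = None):
        # One scale for the whole tensor, derived from its (running) absolute max.
        if tensor_max is None:
            tensor_max = tensor.abs().max()
        scale = (tensor_max.to(torch.float32) / FP8_E4M3_MAX).clamp(min=1e-12)
        q = (tensor.to(torch.float32) / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
        qdq = (q.to(torch.float32) * scale).to(tensor.dtype)
        return qdq, scale.reshape(1)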

auto_round/experimental/kv_cache.py

Lines changed: 15 additions & 62 deletions
@@ -24,6 +24,12 @@
 import torch
 from transformers.cache_utils import DynamicCache
 
+from auto_round.experimental.utils import (
+    is_attention_module,
+    normalize_static_kv_dtype,
+    per_tensor_fp8_qdq,
+    update_parameter_data,
+)
 from auto_round.utils import logger
 
 __all__ = [
@@ -81,13 +87,6 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
     return lst
 
 
-def fp8_per_tensor_qdq(tensor):
-    from auto_round.data_type.fp8 import quant_fp8_sym
-
-    qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
-    return qdq_tensor, scale
-
-
 class QuantizedKVParameterCache(DynamicCache):
     """
     Quantized KV cache used in the forward call based on HF's dynamic cache.
@@ -173,8 +172,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
         assert kv_type == KVCacheScaleType.VALUE
         scales = self.v_scales
 
-        qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
+        qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
         return qdq_tensor
 
 
@@ -192,13 +191,9 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
     quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
     setattr(module, "kv_cache", quantized_kv_cache)
     logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
-
-
-def is_attention_module(module: torch.nn.Module):
-    # FIXME: Handle this better.
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
-    )
+    init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
+    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
+    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)
 
 
 def calibrate_kv_cache_input_hook(
@@ -209,7 +204,6 @@ def calibrate_kv_cache_input_hook(
     kv_cache quantization. Will update the passed in
     kv_cache to singleton QuantizedKVParameterCache.
     """
-    logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
     kv_cache = getattr(module, "kv_cache")
     # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`.
     # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
@@ -221,33 +215,14 @@ def calibrate_kv_cache_input_hook(
     return args, kwargs
 
 
-def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
-    """
-    Update the data of a parameter in a module.
-    If the parameter does not exist, it will be created.
-    """
-    if hasattr(module, name):
-        param = getattr(module, name)
-        if isinstance(param, torch.nn.Parameter):
-            param.data = new_val
-        else:
-            module.register_parameter(name, torch.nn.Parameter(new_val))
-    else:
-        logger.warning(
-            "Parameter %s not found in module %s, creating new parameter."
-            % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
-        )
-        module.register_parameter(name, torch.nn.Parameter(new_val))
-
-
 def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
     """
     Hook to update k_scale and v_scale parameters when running kv_cache quantization.
     """
-    logger.debug(
-        "Calibrate kv_cache output hook for %s %s"
-        % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
-    )
+    # logger.debug(
+    #     "Calibrate kv_cache output hook for %s %s"
+    #     % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
+    # )
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
@@ -261,28 +236,6 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
         module.register_forward_hook(calibrate_kv_cache_output_hook)
 
 
-def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
-    valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
-    valid_torch_dtype = {
-        "float16": torch.float16,
-        "bfloat16": torch.bfloat16,
-        "fp8": torch.float8_e4m3fn,
-        "float8_e4m3fn": torch.float8_e4m3fn,
-        "float32": torch.float32,
-        "float": torch.float32,  # Alias for float32
-    }
-    if static_kv_dtype in valid_dtype_name_lst:
-        new_dtype = valid_torch_dtype[static_kv_dtype]
-    elif static_kv_dtype in valid_torch_dtype.values():
-        new_dtype = static_kv_dtype
-    else:
-        raise ValueError(
-            f"Invalid static kv dtype: {static_kv_dtype}. "
-            f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}."
-        )
-    return new_dtype
-
-
 @contextlib.contextmanager
 def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
     """Context manager for FP8 KV cache quantization operations."""

auto_round/experimental/qmodules/fp8_static.py

Lines changed: 2 additions & 2 deletions
@@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor):
 
     @torch.no_grad()
     def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
-
+        original_dtype = bf16_input.dtype
         qdq_input = self.qdq_input(bf16_input)
         qdq_weight = self.dequant_weight_online()
-        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
+        out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias)
         return out
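The added casts matter because torch.nn.functional.linear requires input and weight to share a dtype, while the qdq helpers can hand back float32 tensors when the module itself runs in bfloat16; a small illustration of the mismatch the change avoids:

    import torch

    x = torch.randn(2, 4, dtype=torch.bfloat16)   # activation in the model's compute dtype
    w = torch.randn(8, 4, dtype=torch.float32)    # e.g. a weight dequantized in float32

    # torch.nn.functional.linear(x, w)            # raises a dtype-mismatch RuntimeError

    out = torch.nn.functional.linear(x, w.to(x.dtype))  # cast to a common dtype, as the patch does
    print(out.dtype)                              # torch.bfloat16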
