
Commit 6d40e20

Revert "Add static FP8 attention support (#1045)" (#1060)
This reverts commit c5b1c41.
1 parent: 4ec50d1 · commit: 6d40e20

File tree

7 files changed: +68, -338 lines


auto_round/__main__.py

Lines changed: 0 additions & 9 deletions
@@ -153,13 +153,6 @@ def __init__(self, *args, **kwargs):
         basic.add_argument(
             "--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
         )
-        basic.add_argument(
-            "--static_kv_dtype", default=None, type=str, help="Data type for static quantize key and value. "
-        )
-
-        basic.add_argument(
-            "--static_attention_dtype ", default=None, type=str, help="Data type for static quantize attention. "
-        )
 
         tuning = self.add_argument_group("Tuning Arguments")
         tuning.add_argument(
@@ -606,8 +599,6 @@ def tune(args):
         layer_config=layer_config,
         model_dtype=args.model_dtype,
         momentum=args.momentum,
-        static_kv_dtype=args.static_kv_dtype,
-        static_attention_dtype=args.static_attention_dtype,
     )
 
     model_name = args.model.rstrip("/")

auto_round/compressors/base.py

Lines changed: 1 addition & 12 deletions
@@ -237,7 +237,6 @@ def __init__(
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         self.momentum = kwargs.pop("momentum", 0.0)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
-        static_attention_dtype = kwargs.pop("static_attention_dtype", None)
         model_dtype = kwargs.pop("model_dtype", None)
         device = kwargs.pop("device", None)
         if envs.AR_USE_MODELSCOPE:
@@ -357,11 +356,6 @@ def __init__(
         if self.static_kv_dtype is not None:
             logger.warning("The static kv is experimental and currently has limited support.")
 
-        # Attention static dtype
-        self.static_attention_dtype = static_attention_dtype
-        if self.static_attention_dtype is not None:
-            logger.warning("The static attention dtype is experimental and currently has limited support.")
-
         self._set_amp_dtype()
         self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
         if self.act_bits <= 8 and self.amp_dtype == torch.float16:
@@ -1010,12 +1004,7 @@ def quantize_and_save(
         kwargs.pop("inplace", None)
 
         # Perform model quantization
-        if self.static_attention_dtype is not None:
-            from auto_round.experimental.attention import attention_quant_ctx
-
-            with attention_quant_ctx(self.model, static_attention_dtype=self.static_attention_dtype):
-                model, _ = self.quantize()
-        elif self.static_kv_dtype is not None:
+        if self.static_kv_dtype is not None:
             from auto_round.experimental.kv_cache import kvcache_quant_context
 
             with kvcache_quant_context(self.model, static_kv_dtype=self.static_kv_dtype):
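
For orientation, a minimal sketch of the control flow that quantize_and_save keeps after this revert: only the static KV-cache branch survives, wrapped in kvcache_quant_context. The standalone helper below is hypothetical; only kvcache_quant_context, static_kv_dtype, and quantize() come from the diff.

from auto_round.experimental.kv_cache import kvcache_quant_context


def quantize_with_optional_static_kv(compressor):
    # Hypothetical wrapper mirroring the branch kept by this revert.
    if compressor.static_kv_dtype is not None:
        # Attention modules are temporarily given a quantized KV cache while
        # calibration/quantization runs inside the context manager.
        with kvcache_quant_context(compressor.model, static_kv_dtype=compressor.static_kv_dtype):
            model, _ = compressor.quantize()
    else:
        model, _ = compressor.quantize()
    return model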

auto_round/experimental/attention.py

Lines changed: 0 additions & 190 deletions
This file was deleted.

auto_round/experimental/kv_cache.py

Lines changed: 62 additions & 15 deletions
@@ -24,12 +24,6 @@
 import torch
 from transformers.cache_utils import DynamicCache
 
-from auto_round.experimental.utils import (
-    is_attention_module,
-    normalize_static_kv_dtype,
-    per_tensor_fp8_qdq,
-    update_parameter_data,
-)
 from auto_round.utils import logger
 
 __all__ = [
@@ -87,6 +81,13 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
     return lst
 
 
+def fp8_per_tensor_qdq(tensor):
+    from auto_round.data_type.fp8 import quant_fp8_sym
+
+    qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
+    return qdq_tensor, scale
+
+
 class QuantizedKVParameterCache(DynamicCache):
     """
     Quantized KV cache used in the forward call based on HF's dynamic cache.
@@ -172,8 +173,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
             assert kv_type == KVCacheScaleType.VALUE
             scales = self.v_scales
 
-        qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
+        qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
+        _pad_and_append_at_idx_(scales, layer_idx, scale)
         return qdq_tensor
 
 
@@ -191,9 +192,13 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
     quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
     setattr(module, "kv_cache", quantized_kv_cache)
     logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
-    init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
-    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
-    update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)
+
+
+def is_attention_module(module: torch.nn.Module):
+    # FIXME: Handle this better.
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
+    )
 
 
 def calibrate_kv_cache_input_hook(
@@ -204,6 +209,7 @@ def calibrate_kv_cache_input_hook(
     kv_cache quantization. Will update the passed in
    kv_cache to singleton QuantizedKVParameterCache.
     """
+    logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
     kv_cache = getattr(module, "kv_cache")
     # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`.
     # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
@@ -215,14 +221,33 @@ def calibrate_kv_cache_input_hook(
     return args, kwargs
 
 
+def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
+    """
+    Update the data of a parameter in a module.
+    If the parameter does not exist, it will be created.
+    """
+    if hasattr(module, name):
+        param = getattr(module, name)
+        if isinstance(param, torch.nn.Parameter):
+            param.data = new_val
+        else:
+            module.register_parameter(name, torch.nn.Parameter(new_val))
+    else:
+        logger.warning(
+            "Parameter %s not found in module %s, creating new parameter."
+            % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
+        )
+        module.register_parameter(name, torch.nn.Parameter(new_val))
+
+
 def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
     """
     Hook to update k_scale and v_scale parameters when running kv_cache quantization.
     """
-    # logger.debug(
-    #     "Calibrate kv_cache output hook for %s %s"
-    #     % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
-    # )
+    logger.debug(
+        "Calibrate kv_cache output hook for %s %s"
+        % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
+    )
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
@@ -236,6 +261,28 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
     module.register_forward_hook(calibrate_kv_cache_output_hook)
 
 
+def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
+    valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
+    valid_torch_dtype = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "fp8": torch.float8_e4m3fn,
+        "float8_e4m3fn": torch.float8_e4m3fn,
+        "float32": torch.float32,
+        "float": torch.float32,  # Alias for float32
+    }
+    if static_kv_dtype in valid_dtype_name_lst:
+        new_dtype = valid_torch_dtype[static_kv_dtype]
+    elif static_kv_dtype in valid_torch_dtype.values():
+        new_dtype = static_kv_dtype
+    else:
+        raise ValueError(
+            f"Invalid static kv dtype: {static_kv_dtype}. "
+            f"Valid options are: {', '.join(valid_dtype_name_lst + list(valid_torch_dtype.values()))}."
+        )
+    return new_dtype
+
+
 @contextlib.contextmanager
 def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
     """Context manager for FP8 KV cache quantization operations."""

auto_round/experimental/qmodules/fp8_static.py

Lines changed: 2 additions & 2 deletions
@@ -115,8 +115,8 @@ def qdq_input(self, bf16_input: torch.Tensor):
 
     @torch.no_grad()
     def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
-        original_dtype = bf16_input.dtype
+
         qdq_input = self.qdq_input(bf16_input)
         qdq_weight = self.dequant_weight_online()
-        out = torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), self.bias)
+        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
         return out
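
To make the two-line change concrete, the two forward variants are shown below as free functions (the signatures are illustrative; in the module they are methods operating on self). The restored path multiplies the qdq tensors directly and assumes they already share a compatible dtype; the reverted-away path cast both back to the caller's dtype first.

import torch
from typing import Optional


def restored_forward(qdq_input: torch.Tensor, qdq_weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Behaviour after the revert: no explicit cast before the matmul.
    return torch.nn.functional.linear(qdq_input, qdq_weight, bias)


def removed_forward(qdq_input: torch.Tensor, qdq_weight: torch.Tensor, original_dtype: torch.dtype, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Behaviour removed by the revert: cast both operands back to the
    # caller's dtype before the matmul.
    return torch.nn.functional.linear(qdq_input.to(original_dtype), qdq_weight.to(original_dtype), bias)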
