import torch
from transformers.cache_utils import DynamicCache

- from auto_round.experimental.utils import (
-     is_attention_module,
-     normalize_static_kv_dtype,
-     per_tensor_fp8_qdq,
-     update_parameter_data,
- )
from auto_round.utils import logger

__all__ = [
@@ -87,6 +81,13 @@ def _pad_and_append_at_idx_(lst: List, idx: int, val: Any) -> list:
    return lst


+ def fp8_per_tensor_qdq(tensor):
+     from auto_round.data_type.fp8 import quant_fp8_sym
+
+     qdq_tensor, scale, _ = quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=0, v=0)
+     return qdq_tensor, scale
+
+
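A minimal usage sketch for the new helper (illustrative only: the tensor shape and dtype are made up, and the exact shape of the returned per-tensor scale depends on quant_fp8_sym):

import torch

key_states = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16)  # stand-in KV tensor
qdq_keys, k_scale = fp8_per_tensor_qdq(key_states)  # FP8 fake-quantize, then dequantize back
assert qdq_keys.shape == key_states.shape  # k_scale is a single per-tensor scale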
class QuantizedKVParameterCache(DynamicCache):
    """
    Quantized KV cache used in the forward call based on HF's dynamic cache.
@@ -172,8 +173,8 @@ def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_
            assert kv_type == KVCacheScaleType.VALUE
            scales = self.v_scales

-         qdq_tensor, scale = per_tensor_fp8_qdq(tensor)
-         _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0))
+         qdq_tensor, scale = fp8_per_tensor_qdq(tensor)
+         _pad_and_append_at_idx_(scales, layer_idx, scale)
        return qdq_tensor


@@ -191,9 +192,13 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4
    quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype)
    setattr(module, "kv_cache", quantized_kv_cache)
    logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
-     init_scale = torch.tensor([0.0], device=next(module.parameters()).device)
-     update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value)
-     update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value)
+
+
+ def is_attention_module(module: torch.nn.Module):
+     # FIXME: Handle this better.
+     return "attention" in module.__class__.__name__.lower() and (
+         hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
+     )
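A rough illustration of the heuristic above (the DummyAttention class below is a hypothetical stand-in for a HF attention module, not a real class):

import torch

class DummyAttention(torch.nn.Module):  # name contains "attention" and exposes k_proj/v_proj
    def __init__(self):
        super().__init__()
        self.k_proj = torch.nn.Linear(64, 64)
        self.v_proj = torch.nn.Linear(64, 64)

assert is_attention_module(DummyAttention())
assert not is_attention_module(torch.nn.Linear(64, 64))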


def calibrate_kv_cache_input_hook(
@@ -204,6 +209,7 @@ def calibrate_kv_cache_input_hook(
    kv_cache quantization. Will update the passed in
    kv_cache to singleton QuantizedKVParameterCache.
    """
+     logger.debug(f"calibrate kv_cache input hook for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}")
    kv_cache = getattr(module, "kv_cache")
    # Starting from transformers 4.55.2, `past_key_value` was renamed to `past_key_values`.
    # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280
@@ -215,14 +221,33 @@ def calibrate_kv_cache_input_hook(
    return args, kwargs


+ def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name: str):
+     """
+     Update the data of a parameter in a module.
+     If the parameter does not exist, it will be created.
+     """
+     if hasattr(module, name):
+         param = getattr(module, name)
+         if isinstance(param, torch.nn.Parameter):
+             param.data = new_val
+         else:
+             module.register_parameter(name, torch.nn.Parameter(new_val))
+     else:
+         logger.warning(
+             "Parameter %s not found in module %s, creating new parameter."
+             % (name, module.__class__.__name__ + str(getattr(module, "layer_idx", "")))
+         )
+         module.register_parameter(name, torch.nn.Parameter(new_val))
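A small sketch of how this helper behaves (the bare module and the "k_scale" name are illustrative; in this file the real scale names come from KVCacheScaleType):

import torch

attn = torch.nn.Module()  # stand-in for an attention module
update_parameter_data(attn, torch.tensor([0.02]), "k_scale")  # parameter missing: created (with a warning)
update_parameter_data(attn, torch.tensor([0.05]), "k_scale")  # parameter exists: its .data is replaced
assert isinstance(attn.k_scale, torch.nn.Parameter)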
+
+
def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
-     # logger.debug(
-     #     "Calibrate kv_cache output hook for %s %s"
-     #     % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
-     # )
+     logger.debug(
+         "Calibrate kv_cache output hook for %s %s"
+         % (module.__class__.__name__, str(getattr(module, "layer_idx", None)))
+     )
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]
@@ -236,6 +261,28 @@ def prep_attention_module_for_calibration(module: torch.nn.Module):
        module.register_forward_hook(calibrate_kv_cache_output_hook)


+ def normalize_static_kv_dtype(static_kv_dtype: Union[str, torch.dtype]) -> torch.dtype:
+     valid_dtype_name_lst = ["float16", "bfloat16", "fp8", "float32", "float"]
+     valid_torch_dtype = {
+         "float16": torch.float16,
+         "bfloat16": torch.bfloat16,
+         "fp8": torch.float8_e4m3fn,
+         "float8_e4m3fn": torch.float8_e4m3fn,
+         "float32": torch.float32,
+         "float": torch.float32,  # Alias for float32
+     }
+     if static_kv_dtype in valid_dtype_name_lst:
+         new_dtype = valid_torch_dtype[static_kv_dtype]
+     elif static_kv_dtype in valid_torch_dtype.values():
+         new_dtype = static_kv_dtype
+     else:
+         raise ValueError(
+             f"Invalid static kv dtype: {static_kv_dtype}. "
281+ f"Valid options are: { ', ' .join (valid_dtype_name_lst + list (valid_torch_dtype .values ()))} ."
+         )
+     return new_dtype
+
+
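For example (return values inferred from the mapping above; strings are looked up by name first, and torch dtypes already in the mapping are passed through unchanged):

import torch

assert normalize_static_kv_dtype("fp8") is torch.float8_e4m3fn
assert normalize_static_kv_dtype(torch.bfloat16) is torch.bfloat16
assert normalize_static_kv_dtype("float") is torch.float32  # "float" is an alias for float32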
@contextlib.contextmanager
def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn):
    """Context manager for FP8 KV cache quantization operations."""