diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 25ec207d0e0..0fbdd1936ef 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -43,6 +43,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         super().__init__()
         if cache_type not in (
@@ -57,7 +58,7 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        self.return_float_values = True
+        self.return_float_values = return_float_values
         self.max_context_length = max_context_length
         cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         scale_shape = (max_batch_size, max_context_length, n_heads, 1)
@@ -400,6 +401,7 @@ def __init__(
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
         use_custom_update_cache_op: bool = False,
+        return_float_values: bool = True,
     ):
         # Look at attention.py for explanation on why max_context_length * 2
         super().__init__(
@@ -409,6 +411,7 @@ def __init__(
             head_dim,
             cache_type,
             use_custom_update_cache_op,
+            return_float_values,
         )
         self.cache_positions_manager = CachePositionsManager(self.max_context_length)
         self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
             head_dim,
             kv_cache.cache_type,
             kv_cache.use_custom_update_cache_op,
+            kv_cache.return_float_values,
         )
@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         # it is not doing causal attention
         if "SDPACustom" in attention.SDPA.__class__.__name__:
             attention.SDPA.use_attention_mask = True
+        # QuantizedSDPA has to store kv_cache in order to obtain
+        # scales and zero points for the k and v caches.
+        # So if we replaced the attention module's quantized kv cache with
+        # QuantizedRingKVCache, then we also have to replace the attention's
+        # SDPA module's kv_cache so that it refers to the same kv_cache.
+        if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
+            attention.SDPA.kv_cache = attention.kv_cache
     return module
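
For context, here is a minimal usage sketch of the new return_float_values flag. It assumes the constructor's leading parameters are named max_batch_size, max_context_length and n_heads (inferred from the cache_shape tuple in the first hunk) and that the module is importable under the executorch package as laid out in this repo; it is illustrative only, not part of the change.

# Illustrative sketch only. Keyword names max_batch_size, max_context_length
# and n_heads are assumed from the cache_shape tuple in the diff; head_dim,
# cache_type, use_custom_update_cache_op and return_float_values appear in
# the signature shown above.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

cache = QuantizedKVCache(
    max_batch_size=1,
    max_context_length=128,
    n_heads=8,
    head_dim=64,
    cache_type=QuantizedCacheType.AffineSymmetric,
    use_custom_update_cache_op=True,
    return_float_values=False,  # previously hard-coded to True in __init__
)

# With return_float_values=False the cache is expected to hand back the int8
# k/v tensors (plus scales and zero points) rather than dequantized float32
# values, which is the form QuantizedSDPA consumes.

That consumption path is also why the last hunk rebinds attention.SDPA.kv_cache to attention.kv_cache after the ring-buffer swap: QuantizedSDPA reads scales and zero points directly from the cache object, so both references must point at the same QuantizedRingKVCache instance.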