14 changes: 13 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
@@ -43,6 +43,7 @@ def __init__(
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
):
super().__init__()
if cache_type not in (
@@ -57,7 +58,7 @@ def __init__(
self.use_custom_update_cache_op = use_custom_update_cache_op
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.return_float_values = True
self.return_float_values = return_float_values
self.max_context_length = max_context_length
cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_context_length, n_heads, 1)
@@ -400,6 +401,7 @@ def __init__(
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
use_custom_update_cache_op: bool = False,
return_float_values: bool = True,
):
# Look at attention.py for an explanation of why max_context_length * 2
super().__init__(
@@ -409,6 +411,7 @@ def __init__(
head_dim,
cache_type,
use_custom_update_cache_op,
return_float_values,
)
self.cache_positions_manager = CachePositionsManager(self.max_context_length)
self.is_ring_buffer = True
@@ -459,6 +462,7 @@ def from_quantized_kv_cache(
head_dim,
kv_cache.cache_type,
kv_cache.use_custom_update_cache_op,
kv_cache.return_float_values,
)


@@ -583,4 +587,12 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
# it is not doing causal attention
if "SDPACustom" in attention.SDPA.__class__.__name__:
attention.SDPA.use_attention_mask = True
# QuantizedSDPA has to store kv_cache in order to obtain
# scales and zero points for the k and v caches.
# So if we replaced the attention module's quantized kv cache with
# QuantizedRingKVCache, then we also have to replace the attention's
# SDPA module's kv_cache so that it refers to the same kv_cache.
if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
attention.SDPA.use_attention_mask = True
attention.SDPA.kv_cache = attention.kv_cache
return module
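
For context, a minimal sketch of the invariant the last hunk establishes: after the transform, a QuantizedSDPA wrapper must alias the same ring kv_cache object that its attention module now holds, since it reads the scales and zero points from that cache. The import path, the `model.layers` attribute, and the layer-size values below are assumptions for illustration, not part of this PR.

# Hypothetical sketch -- module path and model attributes are assumed, not taken from this PR.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)

def check_ring_cache_aliasing(model, layer_sizes):
    """Apply the ring-cache transform and verify SDPA/kv_cache aliasing on quantized layers."""
    model = replace_kv_cache_with_ring_kv_cache(model, layer_sizes)
    for layer in model.layers:  # attribute name assumed for illustration
        attention = layer.attention
        if "QuantizedSDPA" in attention.SDPA.__class__.__name__:
            # The hunk above re-points SDPA's cache, so both references must
            # resolve to the same QuantizedRingKVCache instance.
            assert attention.SDPA.kv_cache is attention.kv_cache
    return model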