diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index a562bdf13fa..69980990cfd 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -56,6 +56,7 @@
     get_quant_weight_transform,
 )
 from .source_transformation.quantized_kv_cache import (
+    replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
 )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
@@ -1058,6 +1059,7 @@ def _get_source_transforms(  # noqa
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
     if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_kv_cache_with_custom_kv_cache)
         transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index a0c8c2fd93b..d8ac99656f1 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -6,6 +6,7 @@
 
 import logging
 from enum import Enum
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -44,7 +45,6 @@ def __init__(
             QuantizedCacheType.AffineSymmetric,
             QuantizedCacheType.AffineAsymmetric,
         ):
-
             raise ValueError(
                 f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
             )
@@ -81,10 +81,11 @@ def __init__(
         )
 
     def _quantize(self, value):
-        scales, zero_points = (
-            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
-                value, self.quantized_cache_dtype
-            )
+        (
+            scales,
+            zero_points,
+        ) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
+            value, self.quantized_cache_dtype
         )
         quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
             value,
@@ -262,3 +263,71 @@ def replace_kv_cache_with_quantized_kv_cache(module):
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
     return module
+
+
+class CustomKVCache(nn.Module):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        dtype=torch.float32,
+    ):
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+
+        self.max_batch_size = max_batch_size
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.register_buffer(
+            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
+        )
+
+    def update(
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # input_pos: [S], k_val: [B, S, H, D]
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
+        return self.k_cache, self.v_cache
+
+
+def replace_kv_cache_with_custom_kv_cache(module):
+    r"""
+    Replace KVCache with CustomKVCache. This modifies the model in place.
+    At the moment custom kv cache only supports cache with shape
+    [B, S, H, D] as opposed to [B, H, S, D]
+    This is because the custom op treats second dim as sequence dim.
+    Future work: support [B, H, S, D]
+    """
+    logging.warning(
+        "Replacing KVCache with CustomKVCache. This modifies the model in place."
+    )
+    for name, child in module.named_children():
+        if isinstance(child, KVCache):
+            cache_shape = child.k_cache.shape
+            cache_dtype = child.k_cache.dtype
+            assert (
+                child.is_transposed is False
+            ), "CustomKVCache does not support transposed cache"
+            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            setattr(
+                module,
+                name,
+                CustomKVCache(
+                    max_batch_size,
+                    max_seq_length,
+                    n_heads,
+                    head_dim,
+                    dtype=cache_dtype,
+                ),
+            )
+        else:
+            replace_kv_cache_with_custom_kv_cache(child)
+    return module
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 59bfbe6f951..4d4b3bf7f56 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -56,33 +56,16 @@ def forward(
         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
 
-        if hasattr(self.kv_cache, "quantized_cache_dtype"):
-            # updated quantize cache, scale and zero points
-            # returns dequantized kv cache
-            # Not most optimal. Optimizations to follow next
-            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
-            output = torch.ops.llama.custom_sdpa(
-                q,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
-        else:
-            output = torch.ops.llama.sdpa_with_kv_cache(
-                q,
-                k,
-                v,
-                k_cache,
-                v_cache,
-                input_pos[0].item(),
-                seqlen,
-                None,  # Attention mask
-                0,  # dropout probability. Ignored by the code
-                True,  # is_causal
-            )
+        k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k_cache,
+            v_cache,
+            input_pos[0].item(),
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal
+        )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
@@ -106,7 +89,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 
 
 class SDPASimple(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
@@ -166,7 +148,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class SDPAFlex(torch.nn.Module):
-
     def __init__(
        self,
        kv_cache: KVCache,
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index 21952d8c211..57c36dabf9b 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -11,6 +11,7 @@
 from executorch.examples.models.llama.llama_transformer import KVCache
 
 from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+    CustomKVCache,
     QuantizedCacheType,
     QuantizedKVCache,
 )
@@ -19,7 +20,6 @@
 
 
 class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
-
     def _init_cache(self):
         self.kv_cache = KVCache(
             self.max_batch_size,
@@ -33,6 +33,19 @@ def _init_cache(self):
         self.quantized_kv_cache = QuantizedKVCache.from_float(
             self.kv_cache, QuantizedCacheType.AffineAsymmetric
         )
+        # Need this because first test actually has seq_len > 1
+        # and vanilla kvcache cannot handle seq_len > 1, due to
+        # how input_pos encoding works in the current stack.
+        # This needs fixing by making sure rest of the stack including
+        # custom ops or other backends can work with input_pos
+        # as a sequence of token positions
+        self.custom_kv_cache = CustomKVCache(
+            self.max_batch_size,
+            self.max_seq_len,
+            self.n_kv_heads,
+            self.head_dim,
+            dtype=self.dtype,
+        )
 
     def _init_kv(self):
         kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
@@ -59,7 +72,7 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k, v = self._init_kv()
-        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
+        self.float_sdpa = SDPACustom(self.custom_kv_cache, self.dim)
         self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
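Note on the cache-write semantics the new CustomKVCache.update relies on: the sketch below is a minimal, pure-PyTorch illustration only. reference_update_cache is a hypothetical stand-in for torch.ops.llama.update_cache (not the actual custom op), written under the assumption stated in the docstring above that the cache layout is [B, S, H, D], i.e. dim 1 is the sequence dimension.

import torch


def reference_update_cache(
    val: torch.Tensor, cache: torch.Tensor, start_pos: int
) -> torch.Tensor:
    # Hypothetical stand-in for torch.ops.llama.update_cache, for illustration only.
    # Assumes [B, S, H, D] layout: dim 1 is the sequence dimension, so a slice of
    # length seq_len is written into positions [start_pos, start_pos + seq_len).
    seq_len = val.size(1)
    cache[:, start_pos : start_pos + seq_len] = val
    return cache


# Usage mirroring CustomKVCache.update for a prefill of 3 tokens at position 0.
k_cache = torch.zeros(1, 8, 2, 4)  # (max_batch_size, max_seq_length, n_heads, head_dim)
k_val = torch.randn(1, 3, 2, 4)    # 3 new tokens
reference_update_cache(k_val, k_cache, start_pos=0)
assert torch.equal(k_cache[:, :3], k_val)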