diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7ebdf95418d..e4146224c3d 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -41,6 +41,7 @@
     get_qnn_quantizer,
     get_vulkan_quantizer,
 )
+from executorch.extension.llm.modules import replace_mha_with_inference_mha
 from executorch.util.activation_memory_profiler import generate_memory_trace
 
 from ..model_factory import EagerModelFactory
@@ -536,7 +537,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     else:
         dtype_override = None
 
-    return (
+    model_manager = (
         _load_llama_model(
             args.model,
             checkpoint=checkpoint_path,
@@ -563,6 +564,15 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         .set_output_dir(output_dir_path)
         .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        if args.use_kv_cache:
+            print("Setting up the KV cache...")
+            model_manager.model.setup_caches(
+                batch_size=1,
+                dtype=dtype_override.to_torch_dtype(),
+                decoder_max_seq_len=args.max_seq_length,
+            )
+    return model_manager
 
 
 def get_quantizer_and_quant_params(args):
@@ -974,6 +984,10 @@ def _get_source_transforms(  # noqa
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
 
+    is_torchtune = modelname in TORCHTUNE_DEFINED_MODELS
+    if is_torchtune:
+        transforms.append(replace_mha_with_inference_mha)
+
     if args.use_spin_quant:
         if args.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
@@ -1075,4 +1089,6 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    print(f"Source transformations: {[t.__name__ for t in transforms]}")
+
     return transforms
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 6d92a45e800..816d3420bab 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -39,6 +39,11 @@ def __init__(
         enable_dynamic_shape=False,
     ):
         super().__init__()
+        self.max_batch_size = max_batch_size
+        self.max_seq_length = max_seq_length
+        self.n_heads = n_heads
+        self.head_dim = head_dim
+        self.cache_type = cache_type
         if cache_type not in (
             QuantizedCacheType.AffineSymmetric,
             QuantizedCacheType.AffineAsymmetric,
@@ -65,6 +70,9 @@ def __init__(
         self.register_buffer(
             "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
+        self.register_buffer(
+            "cache_pos", torch.arange(0, max_seq_length), persistent=False
+        )
         self.register_buffer(
             "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
         )
@@ -95,7 +103,7 @@ def _quantize(self, value):
         )
         return quantized_value, scales, zero_points
 
-    def update(self, input_pos, k_val, v_val):
+    def update(self, k_val, v_val):
         # quantize current k_val and store it in the cache
         quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
 
@@ -110,7 +118,7 @@ def update(self, input_pos, k_val, v_val):
             # for lowering pains of backends that work better
             # with index_put op.
             if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
+                start_pos = self.cache_pos[0].item()
                 torch._check_is_size(start_pos)
                 dim_to_slice = 2 if self.is_transposed else 1
                 torch._check(start_pos < self.k_cache.size(dim_to_slice))
@@ -136,12 +144,12 @@ def update(self, input_pos, k_val, v_val):
                 narrowed_v_scales.copy_(v_scales)
                 narrowed_v_zp.copy_(v_zero_points)
             else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
+                self.k_cache[:, :, self.cache_pos] = quantized_k_val
+                self.k_cache_scales[:, :, self.cache_pos] = k_scales
+                self.k_cache_zero_points[:, :, self.cache_pos] = k_zero_points
+                self.v_cache[:, :, self.cache_pos] = quantized_v_val
+                self.v_cache_scales[:, :, self.cache_pos] = v_scales
+                self.v_cache_zero_points[:, :, self.cache_pos] = v_zero_points
         else:
             # Right now using custom ops on this path.
             # In future we can update custom op to handle transposed cache
@@ -150,7 +158,7 @@ def update(self, input_pos, k_val, v_val):
             # backends such as QNN want to use quantized cache, with dynamic shape,
             # instead of quantizing on their own.
             # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
+            start_pos = self.cache_pos[0].item()
             _ = torch.ops.llama.update_quantized_cache(
                 quantized_k_val, self.k_cache, start_pos
             )
@@ -207,6 +215,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             kv_cache.enable_dynamic_shape,
         )
 
+    def clone(self) -> "QuantizedKVCache":
+        """Create a clone of the KVCache."""
+        if self.is_transposed:
+            num_kv_heads = self.k_cache.shape[1]
+        else:
+            num_kv_heads = self.k_cache.shape[2]
+        clone = QuantizedKVCache(
+            max_batch_size=self.max_batch_size,
+            max_seq_length=self.max_seq_length,
+            n_heads=num_kv_heads,
+            head_dim=self.k_cache.shape[3],
+            cache_type=self.cache_type,
+            tranposed=self.is_transposed,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+        clone.k_cache.copy_(self.k_cache)
+        clone.v_cache.copy_(self.v_cache)
+        clone.cache_pos.copy_(self.cache_pos)
+        clone.k_cache_scales.copy_(self.k_cache_scales)
+        clone.v_cache_scales.copy_(self.v_cache_scales)
+        if clone.cache_type == QuantizedCacheType.AffineAsymmetric:
+            clone.k_cache_zero_points.copy_(self.k_cache_zero_points)
+            clone.v_cache_zero_points.copy_(self.v_cache_zero_points)
+        return clone
+
 
 def replace_kv_cache_with_quantized_kv_cache(module):
     logging.warning(
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index f8362648f32..ac169ae7678 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -23,23 +23,20 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: Union[KVCache, QuantizedKVCache],
-        dim: int,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
         self.kv_cache = kv_cache
-        if not isinstance(kv_cache, QuantizedKVCache):
-            self.kv_cache = kv_cache.to(torch.float)
-        else:
-            assert (
-                kv_cache.cache_fp_type == torch.float32
-            ), "Only float32 is supported for custom SDPA"
-        self.dim = dim
+        # if not isinstance(kv_cache, QuantizedKVCache):
+        #     self.kv_cache = kv_cache.to(torch.float)
+        # else:
+        #     assert (
+        #         kv_cache.cache_fp_type == torch.float32
+        #     ), "Only float32 is supported for custom SDPA"
 
     def forward(
         self,
-        input_pos: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
@@ -60,12 +57,12 @@ def forward(
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
-            k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
+            k_cache, v_cache = self.kv_cache.update(k, v)
             output = torch.ops.llama.custom_sdpa(
                 q,
                 k_cache,
                 v_cache,
-                input_pos[0].item(),
+                self.kv_cache.cache_pos[0].item(),
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
                 True,  # is_causal
@@ -77,13 +74,13 @@ def forward(
                 v,
                 k_cache,
                 v_cache,
-                input_pos[0].item(),
+                self.kv_cache.cache_pos[0].item(),
                 seqlen,
                 None,  # Attention mask
                 0,  # dropout probability. Ignored by the code
                 True,  # is_causal
             )
-        return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
+        return output.view(bsz, seqlen, -1).to(dtype=input_dtype)
 
 
 def _replace_sdpa_with_custom_op(module: torch.nn.Module):
@@ -106,7 +103,6 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
 
 
 class SDPASimple(torch.nn.Module):
-
     def __init__(
         self,
         kv_cache: KVCache,
@@ -122,7 +118,6 @@ def __init__(
 
     def forward(
         self,
-        input_pos: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
@@ -134,8 +129,8 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
-        k, v = self.kv_cache.update(input_pos, k, v)
-        attn_mask = mask[None, None, input_pos]
+        k, v = self.kv_cache.update(k, v)
+        attn_mask = mask[None, None, self.kv_cache.cache_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py
index 8c3943cbcbb..c1b383fa173 100644
--- a/examples/models/llama3_2_vision/text_decoder/model.py
+++ b/examples/models/llama3_2_vision/text_decoder/model.py
@@ -142,13 +142,13 @@ def __init__(self, **kwargs):
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
-        if self.use_kv_cache:
-            print("Setting up KV cache on the model...")
-            self.model_.setup_caches(
-                batch_size=1,
-                dtype=self.dtype,
-                decoder_max_seq_len=self.max_seq_len,
-            )
+        # if self.use_kv_cache:
+        #     print("Setting up KV cache on the model...")
+        #     self.model_.setup_caches(
+        #         batch_size=1,
+        #         dtype=self.dtype,
+        #         decoder_max_seq_len=self.max_seq_len,
+        #     )
 
     def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index eee4aacf44d..b34d2eb264f 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -9,6 +9,13 @@
 
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
+    QuantizedKVCache,
+)
+from executorch.examples.models.llama.source_transformation.sdpa import (
+    SDPACustom,
+    SDPASimple,
+)
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
@@ -145,16 +152,27 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
 
-        self._sdpa = SDPA(
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            head_dim=self.head_dim,
-            attn_dropout=self.attn_dropout if self.training else 0.0,
-            is_causal=self.is_causal,
-            attention_fn=self._attention_call,
+        # self._sdpa = SDPA(
+        #     num_kv_heads=self.num_kv_heads,
+        #     num_heads=self.num_heads,
+        #     head_dim=self.head_dim,
+        #     attn_dropout=self.attn_dropout if self.training else 0.0,
+        #     is_causal=self.is_causal,
+        #     attention_fn=self._attention_call,
+        #     kv_cache=self.kv_cache,
+        # )
+
+        self._sdpa = SDPACustom(
             kv_cache=self.kv_cache,
         )
+        # self._sdpa = SDPASimple(
+        #     kv_cache=self.kv_cache,
+        #     dim=self.embed_dim,
+        #     head_dim=self.head_dim,
+        #     n_rep=self.num_heads // self.num_kv_heads
+        # )
+
         # this flag indicates whether to update the kv-cache during forward
         # passes. when disabled, we can have the cache setup but still
         # perform normal forward passes
@@ -177,13 +195,20 @@ def setup_cache(
                 "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
             )
         else:
-            self.kv_cache = InferenceKVCache(
-                batch_size=batch_size,
-                max_seq_len=max_seq_len,
-                num_kv_heads=self.num_kv_heads,
+            # self.kv_cache = InferenceKVCache(
+            #     batch_size=batch_size,
+            #     max_seq_len=max_seq_len,
+            #     num_kv_heads=self.num_kv_heads,
+            #     head_dim=self.head_dim,
+            #     dtype=dtype,
+            #     transpose_cache=False,
+            # )
+            self.kv_cache = QuantizedKVCache(
+                max_batch_size=batch_size,
+                max_seq_length=max_seq_len,
+                n_heads=self.num_kv_heads,
                 head_dim=self.head_dim,
-                dtype=dtype,
-                transpose_cache=False,
+                # dtype needs to be float32 atm
             )
             self._sdpa.kv_cache = self.kv_cache
             self.cache_enabled = True