From b90b99e13d0d3d898a979ff3a2503548b3714943 Mon Sep 17 00:00:00 2001
From: "Di Xu (SWE)"
Date: Tue, 21 Jan 2025 09:47:50 -0800
Subject: [PATCH] Create a forked KV IO transformer for exporting coreML Llama
 (#7778)

Summary:
Create a forked KV IO transformer for exporting CoreML Llama.

- As discussed in the group chat "Design: KV cache IO on ANE: llama_transformer
  vs. static_llama", we agreed that forking a KV-IO version of llama_transformer
  is best both for code quality and for CoreML/ANE development.
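Calling-convention sketch (illustrative only, not code in this diff): the KV
caches become explicit forward() inputs and outputs instead of in-place-mutated
buffers, which is the form CoreML/ANE export needs. Padding the prefill caches
out to max_seq_len and building a one-hot cache_pos_mask are assumptions about
how a runner would drive the model:

    import torch

    model = Transformer(ModelArgs())  # defaults: use_kv_cache=False -> KV-IO path
    max_seq_len = model.max_seq_len

    # Prefill: no caches in; per-layer caches come out stacked on the last dim,
    # shaped [bs, n_kv_heads, prompt_len, head_dim, n_layers].
    prompt = torch.tensor([[1, 15043, 29892]])  # [bs=1, prompt_len=3]
    logits, k_cache, v_cache = model(tokens=prompt)

    # Assumption: the runner pads the seq dim (dim 2) up to max_seq_len so that
    # every decode step sees static shapes.
    pad = max_seq_len - prompt.shape[1]
    k_cache = torch.nn.functional.pad(k_cache, (0, 0, 0, 0, 0, pad))
    v_cache = torch.nn.functional.pad(v_cache, (0, 0, 0, 0, 0, pad))

    # One decode step: cache_pos_mask is one-hot over the seq dim at the slot
    # this token's K/V should land in (shape assumed [1, 1, max_seq_len, 1]).
    pos = prompt.shape[1]
    cache_pos_mask = torch.zeros(1, 1, max_seq_len, 1)
    cache_pos_mask[:, :, pos] = 1.0
    next_token = logits.argmax(dim=-1, keepdim=True)  # greedy, [1, 1]
    logits, k_cache, v_cache = model(
        tokens=next_token,
        input_pos=torch.tensor([pos]),
        k_cache=k_cache,
        v_cache=v_cache,
        cache_pos_mask=cache_pos_mask,
    )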
Differential Revision: D68424012
---
 .../kv_io_llama/kv_io_llama_transformer.py | 582 ++++++++++++++++++
 examples/models/kv_io_llama/targets.bzl    |   3 +
 examples/models/llama/TARGETS              |  16 +
 3 files changed, 601 insertions(+)
 create mode 100644 examples/models/kv_io_llama/kv_io_llama_transformer.py
 create mode 100644 examples/models/kv_io_llama/targets.bzl

diff --git a/examples/models/kv_io_llama/kv_io_llama_transformer.py b/examples/models/kv_io_llama/kv_io_llama_transformer.py
new file mode 100644
index 00000000000..2d3894ed9ad
--- /dev/null
+++ b/examples/models/kv_io_llama/kv_io_llama_transformer.py
@@ -0,0 +1,582 @@
+# @lint-ignore-every LICENSELINT
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Llama 2 is licensed under the LLAMA 2 Community License,
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+# Please refer to README.md in the same folder for more information.
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from executorch.examples.models.llama.rope import (
+    hf_apply_rotary_emb,
+    hf_precompute_freqs_cis,
+    precompute_freqs_cis,
+    RotaryEmbedding,
+)
+
+from torch import nn
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 2048
+    n_layers: int = 16
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = 8
+    vocab_size: int = 128256
+    hidden_dim: Optional[int] = 8192
+    head_dim: Optional[int] = None  # Optional customized head_dim
+    multiple_of: int = 256
+    ffn_dim_multiplier: Optional[float] = 1.5
+    norm_eps: float = 1e-5
+    max_batch_size: int = 1
+    max_seq_len: int = 32
+    moe: bool = False  # True to enable the MoE (Mixture of Experts)
+    num_experts: int = 8  # Number of experts
+    num_activated_experts: int = 2  # Number of experts to activate
+    use_kv_cache: bool = False  # Use key/value cache
+    use_sdpa_with_kv_cache_op: bool = (
+        False  # Use custom sdpa op that updates kv cache in-place
+    )
+    # Generate logits for all inputs. When True, this significantly increases
+    # memory usage at runtime. Enable it only when necessary (e.g., when using
+    # perplexity tools that require logits for all input tokens).
+    generate_full_logits: bool = False
+    enable_dynamic_shape: bool = False  # export model with dynamic shape support
+    # A dictionary mapping from pruned token-id to original token-id
+    input_prune_map: Optional[Dict[int, int]] = None
+    # A dictionary mapping from pruned token-id to original token-id
+    output_prune_map: Optional[Dict[int, int]] = None
+    use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    rope_theta: Optional[float] = (
+        None  # The official name; overrides self.rope_freq_base.
+    )
+    rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
+    use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
+    # Additional Model Metadata needed at runtime
+    rope_scale_factor: int = 8
+    bos_idx: int = 1
+    eos_idx: int = 3
+    bos_count: int = -1  # i.e., a single EOS is used as BOS
+    eos_count: int = 2
+
+    quantization_args: Optional[dict] = None
+    lora_args: Optional[dict] = None
+
+    def __post_init__(self):
+        if self.n_kv_heads is None:
+            self.n_kv_heads = self.n_heads
+
+        # rope_theta overrides rope_freq_base since it's the official name.
+        if self.rope_theta is not None:
+            self.rope_freq_base = self.rope_theta
+
+        if self.use_sdpa_with_kv_cache_op:
+            assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"
+
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs, compute it
+            # from dim and round it up to a multiple of `self.multiple_of`.
+            multiple_of = self.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = find_multiple(hidden_dim, multiple_of)
+
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
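+
+# Worked example: if hidden_dim were left as None with the other defaults above
+# (dim=2048, ffn_dim_multiplier=1.5, multiple_of=256), __post_init__ derives
+#   4 * 2048 = 8192
+#   int(2 * 8192 / 3) = 5461
+#   int(1.5 * 5461) = 8191
+#   find_multiple(8191, 256) = 8192
+# which matches the default hidden_dim of 8192.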
+
+
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.head_dim,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+            scale_factor=8,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get the precomputed frequencies for the given input position and sequence length.
+
+        Args:
+            input_pos (torch.Tensor): The input position tensor.
+            seq_len (int): The sequence length.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
+        """
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, calling .item() produces symints
+                # because it queries data from the tensor. This path avoids that
+                # for the MPS backend, although the MPS backend can probably
+                # support dynamic shape.
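+                # With the default (non-HF) RoPE tables, freqs_cos has shape
+                # [n_positions, head_dim // 2], so indexing with input_pos of
+                # shape [seq_len] yields [seq_len, head_dim // 2].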
+ freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + + else: + if input_pos is None: + freqs_cos = self.freqs_cos[:seq_len] + freqs_sin = self.freqs_sin[:seq_len] + else: + freqs_cos = self.freqs_cos[input_pos] + freqs_sin = self.freqs_sin[input_pos] + return freqs_cos, freqs_sin + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size: int, + max_seq_length: int, + n_heads: int, + head_dim: int, + transpose_cache: bool, + enable_dynamic_shape: bool, + dtype=torch.float32, + ): + super().__init__() + self.max_seq_length = max_seq_length + self.is_transposed = transpose_cache + if transpose_cache: + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + else: + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.transpose_cache = transpose_cache + self.enable_dynamic_shape = enable_dynamic_shape + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache + if self.enable_dynamic_shape: + start_pos = input_pos[0].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_seq_length) + dim_to_slice = 2 if self.transpose_cache else 1 + seq_length = k_val.size(dim_to_slice) + # Replace the entry in the cache for this token + # The following lines are equivalent to: + # cache_k[:bsz, start_pos : start_pos + seqlen] = xk + # cache_v[:bsz, start_pos : start_pos + seqlen] = xv + # when dim_to_slice is 1 + # We use .narrow() here to make the compiler happy + # pyre-ignore: Incompatible parameter type [6] + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) + # pyre-ignore: Incompatible parameter type [6] + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) + + narrowed_k.copy_(k_val) + narrowed_v.copy_(v_val) + return self.k_cache, self.v_cache + else: + k_out = self.k_cache + v_out = self.v_cache + if self.transpose_cache: + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + else: + k_out[:, input_pos] = k_val + v_out[:, input_pos] = v_val + + return k_out, v_out + + +class SDPA(nn.Module): + def __init__( + self, + kv_cache: KVCache, + dim: int, + head_dim: int, + n_rep: int, + max_seq_len: int, + enable_dynamic_shape: bool, + ): + super().__init__() + self.kv_cache = kv_cache + self.dim = dim + self.head_dim = head_dim + self.n_rep = n_rep + self.max_seq_len = max_seq_len + self.enable_dynamic_shape = enable_dynamic_shape + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) + k: torch.Tensor, # Already have rotary embeddings. 
(bs, seqlen, n_local_kv_heads, head_dim) + v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) + bsz, + seqlen, + mask: torch.Tensor, + ) -> torch.Tensor: + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + k, v = self.kv_cache.update(input_pos, k, v) + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_seq_len) + seq_length = q.size(2) + # pyre-ignore: Incompatible parameter type [6] + attn_mask = mask.narrow(0, start_pos, seq_length) + else: + attn_mask = mask[None, None, input_pos] + + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) + + return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): + super().__init__() + self.use_kv_cache = args.use_kv_cache + self.n_heads = args.n_heads + self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads + print(f"self.n_heads={self.n_heads}, self.n_kv_heads={self.n_kv_heads}") + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.head_dim + self.max_batch_size = args.max_batch_size + self.max_seq_len = args.max_seq_len + self.dim = args.dim + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.layer_id = layer_id + + self.rope = rope + + causal_mask = torch.tril( + torch.ones( + self.max_seq_len, + self.max_seq_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + if self.use_kv_cache: + self.kv_cache = KVCache( + args.max_batch_size, + args.max_seq_len, + self.n_kv_heads, + self.head_dim, + not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. 
Expect untransposed q k v
+                args.enable_dynamic_shape,
+            )
+            self.SDPA = SDPA(
+                kv_cache=self.kv_cache,
+                dim=self.n_local_heads * self.head_dim,
+                head_dim=self.head_dim,
+                n_rep=self.n_rep,
+                max_seq_len=self.max_seq_len,
+                enable_dynamic_shape=args.enable_dynamic_shape,
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
+        k_cache: Optional[torch.Tensor] = None,  # [bs, n_local_kv_heads, seq_len, head_dim]
+        v_cache: Optional[torch.Tensor] = None,  # [bs, n_local_kv_heads, seq_len, head_dim]
+        cache_pos_mask=None,
+    ):
+        bsz, seqlen, _ = x.shape
+        # QKV
+        q, k, v = self.wq(x), self.wk(x), self.wv(x)
+        # We need view_copy elimination
+        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        # RoPE relative positional embeddings
+        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+
+        if self.use_kv_cache:
+            assert input_pos is not None
+            # Note: this in-place KVCache path returns only the attention
+            # output, while the KV-IO path below returns a 5-tuple.
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+            return self.wo(output)
+
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        assert hasattr(self, "mask")
+
+        new_k = k
+        new_v = v
+
+        if input_pos is None:
+            mask = self.mask[:seqlen, :seqlen]
+            k_out = k
+            v_out = v
+        else:
+            assert k_cache is not None
+            assert v_cache is not None
+
+            mask = self.mask[None, None, input_pos]
+
+            # cache_pos_mask selects the cache slot(s) being written this step,
+            # so the update is a broadcasted multiply-add rather than an
+            # in-place scatter.
+            k_update = cache_pos_mask * k
+            k_out = k_cache + k_update
+
+            v_update = cache_pos_mask * v
+            v_out = v_cache + v_update
+
+        # grouped multiquery attention: expand out keys and values
+        k = k_out.repeat_interleave(self.n_rep, dim=1)
+        v = v_out.repeat_interleave(self.n_rep, dim=1)
+
+        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+
+        output = self.wo(output)
+
+        return output, k_out, v_out, new_k, new_v
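+
+# Shape sketch for the masked KV update above (illustrative; the shapes are
+# assumptions about how a runner drives the model, single-token decode):
+#   k_cache:        [bs, n_kv_heads, max_seq_len, head_dim]   running cache
+#   k:              [bs, n_kv_heads, 1,           head_dim]   this step's key
+#   cache_pos_mask: [1,  1,          max_seq_len, 1]          one-hot at input_pos
+# cache_pos_mask * k broadcasts to [bs, n_kv_heads, max_seq_len, head_dim] and
+# is zero everywhere except the input_pos row, so k_cache + (cache_pos_mask * k)
+# "writes" the new key into that slot as a purely functional update, with no
+# in-place scatter. That is the property CoreML/ANE export needs.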
+
+
+class FeedForward(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        assert args.hidden_dim is not None
+        hidden_dim: int = args.hidden_dim
+        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+        super().__init__()
+        self.use_kv_cache = args.use_kv_cache
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.head_dim
+        self.attention = Attention(args, layer_id, rope)
+        self.feed_forward = FeedForward(args)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self, x, freqs_cos, freqs_sin, input_pos=None, k_cache=None, v_cache=None, cache_pos_mask=None
+    ):  # x: 1xN
+        norm_emb = self.attention_norm(x)
+        h, k, v, new_k, new_v = self.attention.forward(
+            norm_emb, freqs_cos, freqs_sin, input_pos, k_cache, v_cache, cache_pos_mask
+        )
+
+        h = x + h
+        out = h + self.feed_forward(self.ffn_norm(h))
+
+        # new_k / new_v (this step's keys/values) are currently unused here.
+        return out, k, v
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.use_kv_cache = params.use_kv_cache
+        self.generate_full_logits = params.generate_full_logits
+        self.max_seq_len = params.max_seq_len
+        self.input_prune_map = params.input_prune_map
+        self.output_prune_map = params.output_prune_map
+
+    def forward(
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating the size of the cache window
+        k_cache: Optional[torch.FloatTensor] = None,
+        v_cache: Optional[torch.FloatTensor] = None,
+        cache_pos_mask=None,
+        h: Optional[torch.FloatTensor] = None,  # embeddings
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # k_cache and v_cache may be None only during prefill.
+        print_k_cache = "None" if k_cache is None else "Val"
+        print_v_cache = "None" if v_cache is None else "Val"
+        print(f"k_cache: {print_k_cache}; v_cache: {print_v_cache}")
+
+        if (tokens is None) ^ (h is not None):
+            raise ValueError(
+                "You cannot specify both tokens and h at the same time, and must specify either one"
+            )
+        if tokens is not None and h is None:
+            h = self.tok_embeddings(tokens)
+        seqlen = h.shape[1]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
+        if input_pos is not None:
+            assert k_cache is not None
+            assert v_cache is not None
+            k_cache_list = torch.unbind(k_cache, dim=-1)
+            v_cache_list = torch.unbind(v_cache, dim=-1)
+        k_out = []
+        v_out = []
+        for i, layer in enumerate(self.layers):
+            if input_pos is not None:
+                k_input = k_cache_list[i]
+                v_input = v_cache_list[i]
+            else:
+                k_input = None
+                v_input = None
+            h, k, v = layer(
+                h,
+                freqs_cos,
+                freqs_sin,
+                input_pos,
+                k_input,
+                v_input,
+                cache_pos_mask,
+            )
+            k_out.append(k)
+            v_out.append(v)
+
+        if not self.generate_full_logits or input_pos is not None:
+            # Only the last logit is used for the new generated token
+            h = h[:, -1, :]
+
+        h = self.norm(h)
+
+        logits = self.output(h)
+
+        return logits, torch.stack(k_out, dim=-1), torch.stack(v_out, dim=-1)
diff --git a/examples/models/kv_io_llama/targets.bzl b/examples/models/kv_io_llama/targets.bzl
new file mode 100644
index 00000000000..f76c3a540f6
--- /dev/null
+++ b/examples/models/kv_io_llama/targets.bzl
@@ -0,0 +1,3 @@
+
+def define_common_targets():
+    pass
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 2037e229345..c814bdc1d41 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -9,6 +9,22 @@ oncall("executorch")
 
 define_common_targets()
 
+runtime.python_library(
+    name = "rope",
+    srcs = [
+        "rope.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_library(
     name = "llama_transformer",
     srcs = [