diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py
index 324f4aa1f2e..a553fcc0d8b 100644
--- a/examples/apple/coreml/llama/llama_transformer.py
+++ b/examples/apple/coreml/llama/llama_transformer.py
@@ -12,6 +12,7 @@
 import torch
 import torch.nn.functional as F
 
+from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
@@ -64,6 +65,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +77,12 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
+    use_qk_norm: bool = False
+    qk_norm_before_rope: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -96,7 +105,7 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
-class RMSNorm(torch.nn.Module):
+class CoreMLRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
@@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
@@ -303,6 +318,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         self.rope = rope
 
+        self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -327,6 +350,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         new_k = k
         new_v = v
 
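
Below is a minimal, self-contained sketch (not part of the patch) of what the new q/k normalization does when `use_qk_norm` is enabled and `qk_norm_before_rope` is left False: each attention head's query/key vectors are RMS-normalized over `head_dim` after the transpose shown in the diff. The inline `RMSNorm` here is a stand-in for `executorch.examples.models.llama.norm.RMSNorm`, and the tensor shapes are assumptions chosen to match the `(bsz, n_heads, seqlen, head_dim)` layout produced by `transpose(1, 2)`.

import torch


class RMSNorm(torch.nn.Module):
    # Stand-in for executorch.examples.models.llama.norm.RMSNorm (assumed standard RMSNorm).
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last dimension (head_dim when used as a q/k norm).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


# Hypothetical sizes for illustration only.
bsz, n_heads, seqlen, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)  # layout after q.transpose(1, 2)
k = torch.randn(bsz, n_heads, seqlen, head_dim)

q_norm_fn = RMSNorm(head_dim)
k_norm_fn = RMSNorm(head_dim)

use_qk_norm, qk_norm_before_rope = True, False
if use_qk_norm and not qk_norm_before_rope:
    # Mirrors the added forward() block: with qk_norm_before_rope=False,
    # q and k are normalized at this point (after RoPE and the transpose).
    q = q_norm_fn(q)
    k = k_norm_fn(k)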