@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.examples.models.llama.norm import RMSNorm
 
 from executorch.examples.models.llama.rope import (
     hf_apply_rotary_emb,
@@ -64,6 +65,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +77,12 @@ def __post_init__(self):
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
+    use_qk_norm: bool = False
+    qk_norm_before_rope: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
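Note: the fields added to ModelArgs in the two hunks above are plain dataclass defaults. A hedged construction sketch (it assumes the remaining ModelArgs fields also have defaults, which this diff does not show):

args = ModelArgs(
    rope_scale_factor=8,         # existing field, wired into RoPE in a hunk below
    high_freq_factor=4,          # new in this diff
    use_kv_cache=True,           # new, defaults to False
    enable_dynamic_shape=False,  # new, defaults to False
    use_qk_norm=True,            # new, turns on per-head RMSNorm of q/k
    qk_norm_before_rope=False,   # new, norm runs after rotary embeddings
)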
@@ -96,7 +105,7 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
-class RMSNorm(torch.nn.Module):
+class CoreMLRMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
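The rename avoids a name clash with the RMSNorm now imported from executorch.examples.models.llama.norm at the top of the file; the CoreML-local class keeps its behavior. Both compute the standard RMS normalization, roughly as follows (generic sketch, not this file's exact code):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide by the root-mean-square over the feature dimension,
    # then apply the learned per-channel gain.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight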
@@ -160,10 +169,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
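The partial(...) bindings above fix scale_factor and high_freq_factor for every later call to precompute_freqs_cis. For context, a minimal sketch of the llama3.1-style frequency scaling these parameters usually drive (standard formulation; the repo's precompute_freqs_cis may differ, and low_freq_factor / old_context_len are assumed defaults, not shown in this diff):

import math
import torch

def scale_rope_freqs(
    freqs: torch.Tensor,
    scale_factor: int = 8,        # matches ModelArgs.rope_scale_factor above
    high_freq_factor: int = 4,    # matches ModelArgs.high_freq_factor above
    low_freq_factor: int = 1,     # assumed default
    old_context_len: int = 8192,  # assumed llama3 pretraining context
) -> torch.Tensor:
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    out = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            out.append(freq)                 # high-frequency bands: unchanged
        elif wavelen > low_freq_wavelen:
            out.append(freq / scale_factor)  # low-frequency bands: fully rescaled
        else:
            # Smoothly interpolate between scaled and unscaled in the mid band.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            out.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(out, dtype=freqs.dtype, device=freqs.device)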
@@ -303,6 +318,14 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         self.rope = rope
 
+        self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
+            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -327,6 +350,10 @@ def forward(
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         new_k = k
         new_v = v
 
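This hunk is the after-RoPE placement; the qk_norm_before_rope flag implies a matching before-RoPE branch elsewhere in forward, which this diff does not show. A self-contained sketch of what the added branch does once q/k are shaped (bsz, n_heads, seqlen, head_dim); sizes here are made up, RMSNorm is the imported executorch class:

import torch
from executorch.examples.models.llama.norm import RMSNorm

bsz, n_heads, seqlen, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)

# Norms are built over head_dim (see the __init__ hunk above), so each
# head's query/key vector is normalized independently while the learned
# weight of size head_dim is shared across heads and positions.
q_norm_fn = RMSNorm(head_dim, eps=1e-5)
k_norm_fn = RMSNorm(head_dim, eps=1e-5)

# qk_norm_before_rope=False: the norm runs here, after rotary embeddings
# have been applied; shapes are unchanged.
q = q_norm_fn(q)
k = k_norm_fn(k)
assert q.shape == (bsz, n_heads, seqlen, head_dim)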