@@ -64,6 +64,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +76,9 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -160,10 +165,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
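
For reference, below is a minimal sketch of how llama3.1-style scaled RoPE frequencies can be derived from the parameters now plumbed through (scale_factor and high_freq_factor). It follows Meta's published llama3.1 scaling rule and is an assumption about what precompute_freqs_cis does with use_scaled=True, not a copy of the patched file; the names apply_llama31_scaling and precompute_freqs_cis_sketch, and the constants low_freq_factor = 1 and old_context_len = 8192, are illustrative.

import math
import torch

def apply_llama31_scaling(freqs, scale_factor=8, high_freq_factor=4):
    # Assumed defaults from the llama3.1 reference scaling.
    low_freq_factor = 1
    old_context_len = 8192
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: leave untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: scale down by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smoothly interpolate between the two bands.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

def precompute_freqs_cis_sketch(
    dim, end, theta=10000.0, use_scaled=False, scale_factor=8, high_freq_factor=4
):
    # Base rotary frequencies for half the head dimension.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    if use_scaled:
        freqs = apply_llama31_scaling(freqs, scale_factor, high_freq_factor)
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.cos(freqs), torch.sin(freqs)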