diff --git a/examples/apple/coreml/llama/llama_transformer.py b/examples/apple/coreml/llama/llama_transformer.py
index 324f4aa1f2e..857266950c3 100644
--- a/examples/apple/coreml/llama/llama_transformer.py
+++ b/examples/apple/coreml/llama/llama_transformer.py
@@ -64,6 +64,8 @@ class ModelArgs:
     use_scaled_rope: bool = True  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
+
     bos_idx: int = 1
     eos_idx: int = 3
     bos_count: int = -1  # i.e., a single EOS is used as BOS
@@ -74,6 +76,9 @@ class ModelArgs:
 
     use_cache_list: bool = True
 
+    use_kv_cache: bool = False
+    enable_dynamic_shape: bool = False
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -160,10 +165,16 @@ def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         if self.params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                hf_precompute_freqs_cis,
+                partial_rotary_factor=self.params.partial_rotary_factor,
+            )
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
+                high_freq_factor=self.params.high_freq_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
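
For context on the new knobs: scale_factor and high_freq_factor are the Llama 3.1 RoPE-scaling parameters that the non-HF branch now passes explicitly into precompute_freqs_cis. Below is a minimal sketch of what the receiving side does, modeled on the Llama 3.1 reference scaling; the helper name apply_scaling and the constants low_freq_factor = 1 and old_context_len = 8192 come from that reference, not from this diff, and the exact signature in executorch's rope utilities may differ.

import math

import torch


def apply_scaling(
    freqs: torch.Tensor, scale_factor: int = 8, high_freq_factor: int = 4
) -> torch.Tensor:
    # Llama 3.1 frequency scaling: keep high frequencies as-is, divide low
    # frequencies by scale_factor, and smoothly interpolate in between.
    low_freq_factor = 1  # reference value (assumption, not from this diff)
    old_context_len = 8192  # original Llama 3 context length (assumption)
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    scale_factor: int = 8,
    high_freq_factor: int = 4,
):
    # Per-dimension rotation frequencies: theta^(-2i/dim) for i in [0, dim/2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    if use_scaled:
        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
    # Outer product with positions gives the angle for every (position, dim) pair.
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.cos(angles), torch.sin(angles)

With use_scaled=False this reduces to standard RoPE. On the HF branch, partial_rotary_factor plays a similar configuration role: it restricts rotation to a leading fraction of each head's dimensions (roughly the first int(head_dim * partial_rotary_factor) of them), which is why hf_precompute_freqs_cis is now wrapped in partial() to bind it.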