From 53f000f91fc809c4c395ba4d0c13f65f47c557b1 Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Fri, 28 Feb 2025 13:39:11 -0800 Subject: [PATCH] Fix static_llama to read some previously hardcoded options from ModelArgs Differential Revision: D70414663 --- examples/qualcomm/oss_scripts/llama/model/static_llama.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 09cc7504224..40044db7428 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() self.dim = config.dim self.n_heads = config.n_heads - self.head_dim = config.dim // config.n_heads + self.head_dim = config.head_dim self.n_kv_heads = config.n_kv_heads self.num_key_value_groups = config.n_heads // self.n_kv_heads self.max_seq_len = config.max_seq_len @@ -304,7 +304,7 @@ def __init__( ): super().__init__() self.dim = config.dim - self.head_dim = config.dim // config.n_heads + self.head_dim = config.head_dim self.max_batch_size = config.max_batch_size self.max_seq_len = config.max_seq_len self.n_heads = config.n_heads @@ -328,9 +328,11 @@ def __init__( self.output = nn.Linear(config.dim, config.vocab_size, bias=False) self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) freqs_cos, freqs_sin = precompute_freqs_cis( - config.dim // config.n_heads, + config.head_dim, config.max_seq_len, config.rope_freq_base, + config.use_scaled_rope, + config.rope_scale_factor, ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False)