@@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
3737 super ().__init__ ()
3838 self .dim = config .dim
3939 self .n_heads = config .n_heads
40- self .head_dim = config .dim // config . n_heads
40+ self .head_dim = config .head_dim
4141 self .n_kv_heads = config .n_kv_heads
4242 self .num_key_value_groups = config .n_heads // self .n_kv_heads
4343 self .max_seq_len = config .max_seq_len
@@ -304,7 +304,7 @@ def __init__(
304304 ):
305305 super ().__init__ ()
306306 self .dim = config .dim
307- self .head_dim = config .dim // config . n_heads
307+ self .head_dim = config .head_dim
308308 self .max_batch_size = config .max_batch_size
309309 self .max_seq_len = config .max_seq_len
310310 self .n_heads = config .n_heads
@@ -328,9 +328,11 @@ def __init__(
328328 self .output = nn .Linear (config .dim , config .vocab_size , bias = False )
329329 self .tok_embeddings = nn .Embedding (config .vocab_size , config .dim )
330330 freqs_cos , freqs_sin = precompute_freqs_cis (
331- config .dim // config . n_heads ,
331+ config .head_dim ,
332332 config .max_seq_len ,
333333 config .rope_freq_base ,
334+ config .use_scaled_rope ,
335+ config .rope_scale_factor ,
334336 )
335337 self .register_buffer ("freqs_cos" , freqs_cos , persistent = False )
336338 self .register_buffer ("freqs_sin" , freqs_sin , persistent = False )
0 commit comments