@@ -143,6 +143,9 @@ def __post_init__(self):
143143 hidden_dim = int (self .ffn_dim_multiplier * hidden_dim )
144144 self .hidden_dim = find_multiple (hidden_dim , multiple_of )
145145
146+ if self .head_dim is None :
147+ self .head_dim = self .dim // self .n_heads
148+
146149
147150class KVCache (nn .Module ):
148151 def __init__ (
@@ -273,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
273276 self .n_local_heads = self .n_heads // model_parallel_size
274277 self .n_local_kv_heads = self .n_kv_heads // model_parallel_size
275278 self .n_rep = self .n_local_heads // self .n_local_kv_heads
276- self .head_dim = args .dim // self . n_heads if args . head_dim is None else args . head_dim
279+ self .head_dim = args .head_dim
277280 self .max_batch_size = args .max_batch_size
278281 self .max_seq_len = args .max_seq_len
279282 self .dim = args .dim
@@ -426,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
426429 self .use_kv_cache = args .use_kv_cache
427430 self .n_heads = args .n_heads
428431 self .dim = args .dim
429- self .head_dim = args .dim // args . n_heads if args . head_dim is None else args . head_dim
432+ self .head_dim = args .head_dim
430433 self .attention = Attention (args , layer_id )
431434 if args .moe :
432435 self .block_sparse_moe = MOEFeedForward (args )
@@ -473,7 +476,7 @@ def __init__(self, params: ModelArgs):
473476 precompute_freqs_cis , use_scaled = params .use_scaled_rope
474477 )
475478 freqs_cos , freqs_sin = self .precompute_freqs_cis (
476- params .dim // params . n_heads if params . head_dim is None else params . head_dim ,
479+ params .head_dim ,
477480 (
478481 params .max_seq_len # Normal llama2.
479482 if params .ffn_dim_multiplier is None
0 commit comments