@@ -265,21 +265,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id

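For reference, a minimal sketch of the size arithmetic these attributes set up, assuming hypothetical Llama-7B-style values (dim=4096, n_heads=32, n_kv_heads=8; only dim and n_heads appear in the comment above, n_kv_heads=8 is an assumption):

    # Hypothetical example values; n_kv_heads=8 is assumed, not taken from this diff.
    dim, n_heads, n_kv_heads = 4096, 32, 8
    head_dim = dim // n_heads        # 4096 // 32 = 128
    n_rep = n_heads // n_kv_heads    # 4: each KV head is shared by 4 query heads (grouped-query attention)
    wq_out = n_heads * head_dim      # 4096 output features for wq (and input features for wo)
    wk_out = n_kv_heads * head_dim   # 1024 output features for wk and wv
    assert (head_dim, n_rep, wq_out, wk_out) == (128, 4, 4096, 1024)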