@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 // 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id

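For reference, here is a minimal standalone sketch (not part of the commit) showing how the sizes derived in this constructor work out after the change. The ModelArgs values below (dim=4096, n_heads=32, n_kv_heads=8) are illustrative assumptions, and the dataclass is a hypothetical stand-in for the real ModelArgs:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelArgs:  # hypothetical minimal stand-in for the real ModelArgs
        dim: int = 4096
        n_heads: int = 32
        n_kv_heads: Optional[int] = 8  # grouped-query attention: fewer KV heads than query heads

    args = ModelArgs()
    n_heads = args.n_heads
    n_kv_heads = n_heads if args.n_kv_heads is None else args.n_kv_heads
    assert n_heads % n_kv_heads == 0
    head_dim = args.dim // n_heads  # 4096 // 32 = 128
    n_rep = n_heads // n_kv_heads   # 32 // 8 = 4 query heads share each KV head

    # Projection layers built in __init__, as nn.Linear(in_features, out_features):
    #   wq: nn.Linear(4096, 32 * 128) = nn.Linear(4096, 4096)
    #   wk: nn.Linear(4096,  8 * 128) = nn.Linear(4096, 1024)
    #   wv: nn.Linear(4096,  8 * 128) = nn.Linear(4096, 1024)
    #   wo: nn.Linear(32 * 128, 4096) = nn.Linear(4096, 4096)
    print(head_dim, n_rep)  # -> 128 4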