diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 4d39d131d1d..fce1340f7ac 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # self.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 128
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.layer_id = layer_id
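For context, a minimal standalone sketch of the shape math the refactored `__init__` preserves. The config values (`dim=4096`, `n_heads=32`, `n_kv_heads=8`), the batch/sequence sizes, and the `repeat_interleave` step are illustrative assumptions, not taken from the PR; the real module reads its sizes from `ModelArgs` and handles KV-head expansion in its own attention path.

```python
# Illustrative shape check for grouped-query attention projections.
# All concrete numbers here are assumptions for the example only.
import torch
import torch.nn as nn

dim, n_heads, n_kv_heads = 4096, 32, 8   # assumed example config
head_dim = dim // n_heads                # 4096 // 32 = 128
n_rep = n_heads // n_kv_heads            # each KV head serves 4 query heads

wq = nn.Linear(dim, n_heads * head_dim, bias=False)     # (B, T, 4096) -> (B, T, 4096)
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)  # (B, T, 4096) -> (B, T, 1024)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)  # (B, T, 4096) -> (B, T, 1024)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)     # (B, T, 4096) -> (B, T, 4096)

x = torch.randn(2, 16, dim)                  # (batch, seq_len, dim)
q = wq(x).view(2, 16, n_heads, head_dim)     # 32 query heads
k = wk(x).view(2, 16, n_kv_heads, head_dim)  # 8 key heads
v = wv(x).view(2, 16, n_kv_heads, head_dim)  # 8 value heads

# Stand-in for KV-head expansion: repeat each KV head n_rep times so K/V
# line up with the 32 query heads before the attention matmul.
k = k.repeat_interleave(n_rep, dim=2)        # (2, 16, 32, 128)
v = v.repeat_interleave(n_rep, dim=2)        # (2, 16, 32, 128)
assert q.shape == k.shape == v.shape
```

The refactor only caches `args.n_heads` on `self` and routes the existing expressions through `self.n_heads`/`self.dim`; the projection shapes above are unchanged by the diff.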