From c083d85f898e6b64d459893826d66dc1a19cb2d5 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Thu, 14 Nov 2024 15:03:17 -0800
Subject: [PATCH] allow customized head_dim

This resolves the request in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/).

Similar change in HF: https://github.com/huggingface/transformers/pull/32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

[ghstack-poisoned]
---
 examples/models/llama/llama_transformer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 20b8b1e30d4..76cd218d65b 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -85,6 +85,7 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     hidden_dim: Optional[int] = None
+    head_dim: Optional[int] = None
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
@@ -272,7 +273,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads
+        self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
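
A minimal sketch (not part of the patch) of how the new field could be exercised, assuming the `ModelArgs` dataclass fields all have defaults as in the hunk above; the import path mirrors the file's location in the repo and the dim/n_heads values are illustrative only:

    # Hypothetical usage sketch; values and import path are assumptions, not from the patch.
    from examples.models.llama.llama_transformer import ModelArgs

    # Default behavior: head_dim is left as None and derived as dim // n_heads (2048 // 32 == 64).
    args_default = ModelArgs(dim=2048, n_heads=32)

    # New behavior: an explicit head_dim overrides the derived value, for checkpoints
    # where head_dim is not equal to dim // n_heads.
    args_custom = ModelArgs(dim=2048, n_heads=32, head_dim=128)

    # Resolution mirrors the patched line in Attention.__init__:
    head_dim = (
        args_custom.dim // args_custom.n_heads
        if args_custom.head_dim is None
        else args_custom.head_dim
    )
    assert head_dim == 128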