Commit 1c87ce3

Update on "allow customized head_dim"
This resolves the request in this [post](https://fb.workplace.com/groups/pytorch.edge.users/permalink/1574875706716050/). Similar change in HF: huggingface/transformers#32502

Differential Revision: [D65974454](https://our.internmc.facebook.com/intern/diff/D65974454/)

[ghstack-poisoned]
2 parents: 0120876 + b0b44b3

1 file changed (+6 additions, −3 deletions)


examples/models/llama/llama_transformer.py

Lines changed: 6 additions & 3 deletions
@@ -143,6 +143,9 @@ def __post_init__(self):
             hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_heads
+
 
 class KVCache(nn.Module):
     def __init__(
@@ -273,7 +276,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // self.n_heads if args.head_dim is None else args.head_dim
+        self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
@@ -426,7 +429,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
+        self.head_dim = args.head_dim
         self.attention = Attention(args, layer_id)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
@@ -473,7 +476,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis, use_scaled=params.use_scaled_rope
         )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads if params.head_dim is None else params.head_dim,
+            params.head_dim,
             (
                 params.max_seq_len  # Normal llama2.
                 if params.ffn_dim_multiplier is None
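
For context, the defaulting that this update moves into `ModelArgs.__post_init__` follows the pattern below. This is a minimal standalone sketch, not the actual `ModelArgs` from `llama_transformer.py`; the default values and the example numbers are illustrative only.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    dim: int = 4096        # illustrative default
    n_heads: int = 32      # illustrative default
    # None means "derive head_dim from dim // n_heads"; any explicit value
    # overrides the derived size, which is the point of this change.
    head_dim: Optional[int] = None

    def __post_init__(self):
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_heads


# Default behavior: head_dim is derived from dim // n_heads.
assert ModelArgs().head_dim == 128
# Customized head_dim: no longer tied to dim // n_heads.
assert ModelArgs(dim=2048, n_heads=8, head_dim=256).head_dim == 256
```

With the derivation centralized in `__post_init__`, the attention, transformer-block, and frequency-precompute call sites can all read `args.head_dim` (or `params.head_dim`) directly instead of repeating the `dim // n_heads` fallback.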
