Skip to content

Commit 27ff93d

Browse files
authored
doc: add n_query_groups to attention notation table (#2092)
1 parent 50f6bc4 commit 27ff93d

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

litgpt/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,26 @@ def forward(
377377
# - T | time-step (sequence length)
378378
# - C | model's embeddings size (n_embd)
379379
# - C* | attentions's embeddings size
380-
# - nh_(q,k,v) | number of heads for query, key and value
381380
# - hs | head size
381+
# - nh_(q,k,v) | number of heads for query, key and value
382+
# - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
383+
# alternative notation: num_kv_groups = n_query_groups
384+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
385+
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
386+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
387+
# │ │ │ │ │ │ │
388+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
389+
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
390+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
391+
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
392+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
393+
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
394+
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
395+
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
396+
# MHA GQA MQA
397+
# n_query_groups=4 n_query_groups=2 n_query_groups=1
398+
#
399+
# credit https://arxiv.org/pdf/2305.13245.pdf
382400
head_size = self.config.head_size
383401
n_head = self.config.n_head
384402
n_query_groups = self.config.n_query_groups

0 commit comments

Comments
 (0)