@@ -377,8 +377,26 @@ def forward(
377
377
# - T | time-step (sequence length)
378
378
# - C | model's embedding size (n_embd)
379
379
# - C* | attention's embedding size
380
- # - nh_(q,k,v) | number of heads for query, key and value
381
380
# - hs | head size
381
+ # - nh_(q,k,v) | number of heads for query, key and value
382
+ # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
383
+ # alternative notation: num_kv_groups = n_query_groups
384
+ #   ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
385
+ #   │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
386
+ #   └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
387
+ #     │    │    │    │         │        │                 │
388
+ #   ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
389
+ #   │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
390
+ #   └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
391
+ #     │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
392
+ #   ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
393
+ #   │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
394
+ #   └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
395
+ #   ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
396
+ #           MHA                   GQA                   MQA
397
+ #     n_query_groups=4      n_query_groups=2      n_query_groups=1
398
+ #
399
+ # diagram credit: GQA paper (Ainslie et al., 2023), https://arxiv.org/pdf/2305.13245.pdf
382
400
head_size = self .config .head_size
383
401
n_head = self .config .n_head
384
402
n_query_groups = self .config .n_query_groups
0 commit comments