@@ -408,7 +408,7 @@ def forward(
408
408
qkv = self .qkv (x ) # (B, T, 3xC*)
409
409
410
410
# Define query, key and value sizes.
411
- # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config` ).
411
+ # If grouped/multi query is enabled, these sizes are not equal (see the diagram above ).
412
412
query_size = n_head * head_size
413
413
key_size = value_size = n_query_groups * head_size
414
414
# Split qkv into query, key and value matrices.
@@ -420,9 +420,12 @@ def forward(
420
420
421
421
# To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
422
422
# embedding size (C) into num_heads (nh) and head_size (hs).
423
+
424
+ # This follows the original GQA paper and uses the term "query groups".
425
+ # Alternative notation: query groups are also referred to as KV groups.
423
426
q = q .view (B , T , n_head , head_size ) # (B, T, nh_q, hs)
424
- k = k .view (B , T , n_query_groups , head_size ) # (B, T, nh_k , hs)
425
- v = v .view (B , T , n_query_groups , head_size ) # (B, T, nh_v , hs)
427
+ k = k .view (B , T , n_query_groups , head_size ) # (B, T, n_query_groups , hs)
428
+ v = v .view (B , T , n_query_groups , head_size ) # (B, T, n_query_groups , hs)
426
429
427
430
# The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are
428
431
# multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector
0 commit comments