Commit d32a738

doc: add comments for clarifying query / KV groups (#2093)
1 parent 27ff93d commit d32a738

1 file changed, +6 -3 lines changed

litgpt/model.py

Lines changed: 6 additions & 3 deletions
@@ -408,7 +408,7 @@ def forward(
 qkv = self.qkv(x) # (B, T, 3xC*)
 
 # Define query, key and value sizes.
-# If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`).
+# If grouped/multi query is enabled, these sizes are not equal (see the diagram above).
 query_size = n_head * head_size
 key_size = value_size = n_query_groups * head_size
 # Split qkv into query, key and value matrices.
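
As a rough illustration of the size comment in this hunk, the sketch below evaluates the two formulas for a multi-head, a grouped-query, and a multi-query setting. The helper name `qkv_sizes` and the example numbers are assumptions made for illustration and are not part of litgpt; only the two formulas are taken from the diff.

```python
def qkv_sizes(n_head: int, n_query_groups: int, head_size: int) -> tuple[int, int, int]:
    """Return (query_size, key_size, value_size) using the formulas from the hunk."""
    query_size = n_head * head_size
    key_size = value_size = n_query_groups * head_size
    return query_size, key_size, value_size


# Hypothetical configurations, all with 8 query heads of size 64.
mha = qkv_sizes(n_head=8, n_query_groups=8, head_size=64)  # multi-head: one KV group per head
gqa = qkv_sizes(n_head=8, n_query_groups=2, head_size=64)  # grouped-query: 2 KV groups
mqa = qkv_sizes(n_head=8, n_query_groups=1, head_size=64)  # multi-query: 1 KV group

print(mha)  # (512, 512, 512)
print(gqa)  # (512, 128, 128)
print(mqa)  # (512, 64, 64)
```

The three sizes only coincide when `n_query_groups == n_head`; in the other two cases the key and value slices of `qkv` are narrower than the query slice.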
@@ -420,9 +420,12 @@ def forward(
 
 # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
 # embedding size (C) into num_heads (nh) and head_size (hs).
+
+# The original GQA paper is followed here and the term query groups is used.
+# alternative notation: Query groups are also referred to as KV groups.
 q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs)
-k = k.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs)
-v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_v, hs)
+k = k.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs)
+v = v.view(B, T, n_query_groups, head_size) # (B, T, n_query_groups, hs)
 
 # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are
 # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector
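
To make the renamed shape comments concrete, here is a minimal shape-bookkeeping sketch in plain Python (no tensors are created). The helper `shapes_after_view`, the key `q_heads_per_kv_group`, and the example dimensions are hypothetical; the sketch also assumes that `n_head` is divisible by `n_query_groups`, which the diff itself does not state.

```python
def shapes_after_view(B, T, n_head, n_query_groups, head_size):
    """Shapes of q, k and v after the three .view(...) calls in the hunk."""
    # Assumption of this sketch: query heads divide evenly across the KV groups.
    if n_head % n_query_groups != 0:
        raise ValueError("n_head must be divisible by n_query_groups")
    return {
        "q": (B, T, n_head, head_size),          # (B, T, nh_q, hs)
        "k": (B, T, n_query_groups, head_size),  # (B, T, n_query_groups, hs)
        "v": (B, T, n_query_groups, head_size),  # (B, T, n_query_groups, hs)
        # Number of query heads that share one key/value group.
        "q_heads_per_kv_group": n_head // n_query_groups,
    }


shapes = shapes_after_view(B=2, T=16, n_head=8, n_query_groups=2, head_size=64)
print(shapes["q"])                     # (2, 16, 8, 64)
print(shapes["k"])                     # (2, 16, 2, 64)
print(shapes["q_heads_per_kv_group"])  # 4
```

Under these example numbers, `k` and `v` carry 2 groups rather than 8 heads, which is why the commit replaces the `nh_k` and `nh_v` labels with `n_query_groups`.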
