Commit d50db6a

Update GQA comments
1 parent 5b9df45 commit d50db6a

File tree

1 file changed

+13
-0
lines changed


ch04/04_gqa/gpt_with_kv_gqa_reference.py

Lines changed: 13 additions & 0 deletions
@@ -81,6 +81,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         num_tokens_q = queries.shape[-2]
         num_tokens_k = keys.shape[-2]
         device = queries.device
+
+        # Causal Masking with a KV Cache
+        # ------------------------------
+        # To mask correctly, we must align the Query and Key tensors using their
+        # "Absolute Positions" in the full text sequence.
+        #
+        # 1. Queries: The new tokens start at `self.ptr_current_pos`.
+        #
+        # 2. Keys: In this infinite-cache implementation, the cache always begins
+        #    at Absolute Position 0.
+        #
+        # (Note: If we were using a sliding window, we would calculate the start
+        #  position as `total_tokens_processed - current_cache_size`.)
         q_positions = torch.arange(
             self.ptr_current_pos,
             self.ptr_current_pos + num_tokens_q,
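The alignment described in the new comments can be sketched standalone. The variable names below mirror those in the patch (`ptr_current_pos`, `num_tokens_q`, `num_tokens_k`), but the concrete values and the comparison used to build the mask are illustrative assumptions, not the file's actual code:

```python
import torch

# Suppose 4 tokens are already in the (infinite) KV cache and 2 new tokens
# arrive. These values are made up for illustration.
ptr_current_pos = 4
num_tokens_q = 2
num_tokens_k = ptr_current_pos + num_tokens_q  # cache starts at absolute position 0

# Absolute positions of the new queries: they start at ptr_current_pos.
q_positions = torch.arange(ptr_current_pos, ptr_current_pos + num_tokens_q)  # [4, 5]
# Absolute positions of the cached keys: the cache always begins at 0.
k_positions = torch.arange(0, num_tokens_k)  # [0, 1, 2, 3, 4, 5]

# A key is visible to a query iff its absolute position is <= the query's,
# so entries where the key position exceeds the query position get masked.
mask = k_positions[None, :] > q_positions[:, None]  # True = masked out, shape (2, 6)
print(mask)
```

Because both tensors are expressed in absolute positions, the first new query (position 4) masks only the future token at position 5, while the last query sees everything.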
