
Commit 4da1095

talentJay-ux authored and SethHWeidman committed
Sliding window KV Cache bug fix (rasbt#925)
1. Fix a bug where the KV cache and GPT's `ptr` pointer are not reset when window_size > context_length. 2. Fix a bug where the KV cache and GPT's `ptr` pointer are not reset. 3. Fix a KV cache import issue for gpt_with_kv_cache_optimized.
1 parent d50db6a commit 4da1095
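The core of fixes (1) and (2) is a chunked prefill: when the prompt is longer than the KV-cache window, it is fed to the model in window-sized slices instead of a single call, and the number of new tokens is capped so positional indices never run past the context length. Below is a minimal, self-contained sketch of that logic; `DummyModel` and `prefill_in_chunks` are hypothetical stand-ins (not part of this repo) so the chunking arithmetic can be run on its own.

import torch
import torch.nn as nn


class DummyModel(nn.Module):
    # Hypothetical stand-in for GPTModel: it only records how the prompt
    # is split into chunks and returns random "logits".
    def __init__(self, vocab_size=100):
        super().__init__()
        self.vocab_size = vocab_size
        self.seen_chunk_sizes = []

    def forward(self, idx, use_cache=False):
        self.seen_chunk_sizes.append(idx.size(1))
        return torch.randn(idx.size(0), idx.size(1), self.vocab_size)


def prefill_in_chunks(model, idx, max_new_tokens, ctx_len, kv_window_size):
    # Feed the prompt in slices of at most kv_window_size tokens,
    # mirroring the loop this commit adds to generate_text_simple_cached
    input_tokens = idx[:, -ctx_len:]
    input_len = input_tokens.size(1)

    for i in range(0, input_len, kv_window_size):
        chunk = input_tokens[:, i:i + kv_window_size]
        logits = model(chunk, use_cache=True)

    # The position embedding only covers ctx_len positions,
    # so cap how many new tokens can still be generated
    max_new_tokens = min(max_new_tokens, ctx_len - input_len)
    return logits, max_new_tokens


if __name__ == "__main__":
    model = DummyModel()
    prompt = torch.randint(0, 100, (1, 10))  # 10-token prompt
    logits, budget = prefill_in_chunks(
        model, prompt, max_new_tokens=8, ctx_len=12, kv_window_size=4
    )
    print(model.seen_chunk_sizes)  # [4, 4, 2] -> prompt processed in 3 chunks
    print(budget)                  # 2 -> only 12 - 10 positions remain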

File tree

2 files changed: +117 -3 lines changed

ch04/03_kv-cache/gpt_with_kv_cache_optimized.py

Lines changed: 27 additions & 1 deletion
@@ -37,6 +37,12 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
     def forward(self, x, use_cache=False):
         b, num_tokens, d_in = x.shape
 
+        if use_cache:
+            # Prevent self.ptr_cur from becoming negative
+            assert num_tokens <= self.window_size, (
+                f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size})."
+            )
+
         keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
         values_new = self.W_value(x)
         queries = self.W_query(x)
@@ -221,6 +227,7 @@ def __init__(self, cfg):
 
         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+        self.kv_window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
 
     def forward(self, in_idx, use_cache=False):
         batch_size, seq_len = in_idx.shape
@@ -232,6 +239,12 @@ def forward(self, in_idx, use_cache=False):
         # NEW
 
         if use_cache:
+            context_length = self.pos_emb.num_embeddings
+            # Prevent generating a sequence longer than context_length, since going past it
+            # causes an out-of-bounds error when reading the position embedding
+            assert self.ptr_current_pos + seq_len <= context_length, (
+                f"Position embedding overflow: trying to read position {self.ptr_current_pos + seq_len}, "
+                f"which exceeds the embedding size of {context_length}"
+            )
             pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
             self.ptr_current_pos += seq_len
         else:
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
     ctx_len = context_size or model.pos_emb.num_embeddings
+    kv_window_size = model.kv_window_size
 
     with torch.no_grad():
         if use_cache:
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+            input_tokens = idx[:, -ctx_len:]
+            input_tokens_length = input_tokens.size(1)
+
+            # Prefill in chunks to handle input_tokens_length > kv_window_size
+            for i in range(0, input_tokens_length, kv_window_size):
+                chunk = input_tokens[:, i:i+kv_window_size]
+                logits = model(chunk, use_cache=True)
+
+            # Can't generate more than ctx_len tokens in total,
+            # due to the limitation of the position embedding
+            max_generable = ctx_len - input_tokens_length
+            max_new_tokens = min(max_new_tokens, max_generable)
 
             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
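With this change, a prompt longer than kv_window_size can be passed directly to generate_text_simple_cached. The sketch below mirrors the new test_prefill_chunking_basic test further down; the downsized emb_dim/n_heads/n_layers values are my own (hypothetical) choices to keep the example lightweight, on the assumption that GPTModel is fully config-driven.

import torch
from gpt_with_kv_cache_optimized import GPTModel, generate_text_simple_cached

# Hypothetical, downsized config: the keys match those used in tests.py,
# but emb_dim / n_heads / n_layers are shrunk so the sketch runs quickly
cfg = {
    "vocab_size": 50257,
    "context_length": 20,
    "emb_dim": 96,
    "n_heads": 4,
    "n_layers": 2,
    "drop_rate": 0.1,
    "qkv_bias": False,
    "kv_window_size": 4,  # smaller than the 10-token prompt below
}

torch.manual_seed(123)
model = GPTModel(cfg)
model.eval()

prompt = torch.randint(0, cfg["vocab_size"], (1, 10))

# The 10-token prompt is prefilled in chunks of 4, and max_new_tokens
# is capped at context_length - prompt_length = 10, so 2 is fine
token_ids = generate_text_simple_cached(
    model=model,
    idx=prompt,
    max_new_tokens=2,
    context_size=cfg["context_length"],
    use_cache=True,
)
print(token_ids.shape)  # expected: torch.Size([1, 12])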

ch04/03_kv-cache/tests.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from gpt_with_kv_cache import GPTModel as GPTModelKV1
1111
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
12-
from gpt_with_kv_cache import generate_text_simple_cached
12+
from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
13+
from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2
1314

1415

1516
GPT_CONFIG_124M = {
@@ -20,6 +21,7 @@
2021
"n_layers": 12,
2122
"drop_rate": 0.1,
2223
"qkv_bias": False,
24+
"kv_window_size": 1024 # NEW: KV cache window size
2325
}
2426

2527

@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
8082
max_new_tokens=30,
8183
context_size=GPT_CONFIG_124M["context_length"]
8284
)
85+
elif ModelClass is GPTModelKV1:
86+
token_ids = generate_text_simple_cachedKV1(
87+
model=model,
88+
idx=encoded_tensor,
89+
max_new_tokens=30,
90+
context_size=GPT_CONFIG_124M["context_length"]
91+
)
8392
else:
84-
token_ids = generate_text_simple_cached(
93+
token_ids = generate_text_simple_cachedKV2(
8594
model=model,
8695
idx=encoded_tensor,
8796
max_new_tokens=30,
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
99108
assert torch.equal(base_output, other_output), (
100109
f"Mismatch between {base_name} and {other_name}"
101110
)
111+
112+
113+
def test_context_overflow_bug():
114+
"""
115+
Test that demonstrates the ptr_current_pos overflow bug.
116+
117+
In old implementation:
118+
- context_length = 10 (positions 0-9 available)
119+
- We try to generate 15 tokens total (5 input + 10 generated)
120+
- At token 11 (position 10), it crashes trying to access pos_emb[10]
121+
"""
122+
GPT_CONFIG_SMALL = {
123+
"vocab_size": 50257,
124+
"context_length": 10, # Very small context
125+
"emb_dim": 768,
126+
"n_heads": 12,
127+
"n_layers": 12,
128+
"drop_rate": 0.1,
129+
"qkv_bias": False,
130+
"kv_window_size": 20 # Larger than context_length
131+
}
132+
133+
torch.manual_seed(123)
134+
135+
model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
136+
model.eval()
137+
138+
# 5 input tokens
139+
input_tokens = torch.randint(0, 50257, (1, 5), device=device)
140+
141+
generate_text_simple_cachedKV2(
142+
model=model,
143+
idx=input_tokens,
144+
max_new_tokens=10, # 5 + 10 = 15 > 10 context_length
145+
context_size=GPT_CONFIG_SMALL["context_length"],
146+
use_cache=True
147+
)
148+
149+
150+
def test_prefill_chunking_basic():
151+
"""
152+
Test that prefill correctly chunks input when input_length > kv_window_size.
153+
154+
Setup:
155+
- kv_window_size = 4
156+
- input_length = 10
157+
- Should process in 3 chunks: [0:4], [4:8], [8:10]
158+
"""
159+
config = {
160+
"vocab_size": 50257,
161+
"context_length": 20,
162+
"emb_dim": 768,
163+
"n_heads": 12,
164+
"n_layers": 12,
165+
"drop_rate": 0.1,
166+
"qkv_bias": False,
167+
"kv_window_size": 4 # Small window to force chunking
168+
}
169+
170+
torch.manual_seed(123)
171+
model = GPTModelKV2(config).to(device)
172+
model.eval()
173+
174+
# 10 input tokens (> kv_window_size of 4)
175+
input_tokens = torch.randint(0, 50257, (1, 10), device=device)
176+
177+
# Should successfully process all input in chunks
178+
token_ids = generate_text_simple_cachedKV2(
179+
model=model,
180+
idx=input_tokens,
181+
max_new_tokens=2,
182+
use_cache=True
183+
)
184+
185+
# Should have 10 input + 2 generated = 12 total
186+
assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"
187+
188+
# First 10 tokens should match input
189+
assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"
