Commit ac2c774

test: add a diagnostic script for prefix caching naning (#1987)

Signed-off-by: Terry Kong <terryk@nvidia.com>

1 parent 42f3043

File tree

3 files changed: +155 -1 lines changed


docs/adding-new-models.md

Lines changed: 51 additions & 1 deletion
@@ -311,4 +311,54 @@ The different compilation modes offer distinct trade-offs between accuracy and p
- **Eager vs CUDA Graph failures are normal** - don't panic if this fails
- **Focus on patterns** - some models are more sensitive than others
- **Use as guidance** - helps choose reliable compilation settings
- **Balance precision vs performance** - choose what works for your use case

## [5.prefix_caching_nan.py](https://github.com/NVIDIA-NeMo/RL/blob/main/tools/model_diagnostics/5.prefix_caching_nan.py)

Tests that prefix caching doesn't produce NaN logprobs when prior generation is rolled back into the prompt (the standard RL / multi-turn pattern). In vLLM >= 0.14, the second request can return all-NaN logprobs with `token_id=0` (`<unk>`) for every token after the first.

```sh
# Single version (requires 2+ GPUs for TP=2)
uv run --no-project --with "vllm==0.14.0" tools/model_diagnostics/5.prefix_caching_nan.py

# Test across multiple vLLM versions:
for ver in 0.11.2 0.13.0 0.14.0 0.15.0 0.15.1; do
  uv run --no-project --with "vllm==$ver" tools/model_diagnostics/5.prefix_caching_nan.py 2>&1 | tee "prefix_caching_nan_vllm_${ver}.log"
done
```
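The "prompt length: 13990 chars" figure in the logs below is fully determined by the script's constants; a quick sketch reproducing just the prompt construction:

```python
# Rebuild the counting prompt exactly as the diagnostic script does.
COUNT_UP_TO = 3000

numbers = " ".join(str(i) for i in range(1, COUNT_UP_TO + 1))
prompt = (
    "You are a counting assistant. Output ONLY numbers separated by spaces.\n\n"
    f"User: Continue counting: {numbers} "
)

print(len(prompt))  # 13990, matching the iteration-1 prompt length in the logs
```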

Expected pass output (vLLM 0.13.0):
```
Iteration 1 — prompt length: 13990 chars
tokens: 2048, finish_reason: length
text (first 100): '3001 3002 3003 3004 3005 3006 3007 3008 3009 3010 3011 3012 3013 3014 3015 3016 3017 3018 3019 3020 '

Iteration 2 — prompt length: 16038 chars
tokens: 2048, finish_reason: length
text (first 100): '1 3412 3413 3414 3415 3416 3417 3418 3419 3420 3421 3422 3423 3424 3425 3426 3427 3428 3429 3430 343'

[nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16] ALL GOOD!
```

Expected fail output (vLLM 0.14.0):
```
Iteration 1 — prompt length: 13990 chars
tokens: 2048, finish_reason: length
text (first 100): '3000\n\nAssistant: 2600 2601 2602 2603 2604 2605 2606 2607 2609 2610 ...'

Iteration 2 — prompt length: 16047 chars
tokens: 2048, finish_reason: length
text (first 100): '3'

Sample logprobs from iteration 2:
token[0] id=1051: logprob=-0.0005862186080776155 decoded='3'
token[1] id=0: logprob=nan decoded='<unk>'
token[2] id=0: logprob=nan decoded='<unk>'
token[2047] id=0: logprob=nan decoded='<unk>'

AssertionError: FAIL: 2047/2048 logprobs are NaN on iteration 2 (prefix caching is broken in vLLM 0.14.0)
```

Note: the `ERROR ... Engine core proc EngineCore_DP0 died unexpectedly` message that may appear after the assertion is just vLLM's engine shutting down ungracefully after the process exits — it is not a separate issue.

The script generates from a counting prompt, appends the output back into the prompt, and generates again. On the second generation, prefix caching reuses the KV cache from the first request's prefix. The bug causes the cached prefix to produce corrupted activations, resulting in `token_id=0` (`<unk>`) with `logprob=nan` for all tokens after the first.
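The failure check itself is small enough to sketch standalone. This assumes logprobs arrive as a list of per-step dicts mapping token id to logprob (the shape vLLM returns), with plain floats standing in for vLLM's `Logprob` objects:

```python
import math

def count_nan_logprobs(logprobs):
    """Count generation steps whose sampled-token logprob is NaN.

    `logprobs` mimics vLLM's per-step structure: a list of dicts
    mapping token_id -> logprob. None entries are skipped, and only
    the first entry per step (the sampled token) is inspected.
    """
    nan_count = 0
    for step in logprobs:
        if step is None:
            continue
        for _tid, lp in step.items():
            if isinstance(lp, float) and math.isnan(lp):
                nan_count += 1
            break
    return nan_count

# Healthy iteration: finite logprobs throughout.
print(count_nan_logprobs([{101: -0.02}, {102: -0.5}]))  # 0
# Broken iteration: token_id=0 with logprob=nan after the first token.
print(count_nan_logprobs([{1051: -0.0006}, {0: float("nan")}, {0: float("nan")}]))  # 2
```

A nonzero count on the second request is exactly the condition that trips the script's final assertion.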

pyrefly.toml

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ project-includes = [
"tools/model_diagnostics/1.max_model_len_respected.py",
"tools/model_diagnostics/2.long_generation_decode_vs_prefill.py",
"tools/model_diagnostics/4.vllm_precision_compilation_test.py",
"tools/model_diagnostics/5.prefix_caching_nan.py",
]

# Disable TypedDict mutation errors since TypedDict objects are regular dicts at runtime

tools/model_diagnostics/5.prefix_caching_nan.py
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
```py
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prefix caching NaN reproducer.

Tests that prefix caching doesn't produce NaN logprobs when prior generation
is rolled back into the prompt (the standard RL / multi-turn pattern).

Known failure: vLLM >= 0.14 may return token_id=0 (<unk>) with logprob=nan
for every token after the first on the second request.

Usage:
    python 5.prefix_caching_nan.py
    python 5.prefix_caching_nan.py --model meta-llama/Llama-3.1-8B-Instruct
"""

import argparse
import math

MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
TP = 2
MAX_TOKENS = 2048
MAX_MODEL_LEN = 32768
COUNT_UP_TO = 3000

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--tp", type=int, default=TP)
args = parser.parse_args()

# Import vLLM only after argument parsing so arg errors surface quickly.
import vllm
from vllm import LLM, SamplingParams

print(f"vLLM version: {vllm.__version__}")

numbers = " ".join(str(i) for i in range(1, COUNT_UP_TO + 1))
prompt = (
    "You are a counting assistant. Output ONLY numbers separated by spaces.\n\n"
    f"User: Continue counting: {numbers} "
)

llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tp,
    enable_prefix_caching=True,
    max_model_len=MAX_MODEL_LEN,
    gpu_memory_utilization=0.90,
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=MAX_TOKENS, logprobs=1)

# Iteration 1: initial generation (builds the prefix cache)
print(f"\nIteration 1 — prompt length: {len(prompt)} chars")
out1 = llm.generate([prompt], sampling_params)[0].outputs[0]
print(f"  tokens: {len(out1.token_ids)}, finish_reason: {out1.finish_reason}")
print(f"  text (first 100): {out1.text[:100]!r}")

# Iteration 2: extend prompt with prior output (triggers prefix cache reuse)
prompt += out1.text
print(f"\nIteration 2 — prompt length: {len(prompt)} chars")
out2 = llm.generate([prompt], sampling_params)[0].outputs[0]
print(f"  tokens: {len(out2.token_ids)}, finish_reason: {out2.finish_reason}")
print(f"  text (first 100): {out2.text[:100]!r}")

# Check for NaN logprobs (only the sampled token of each step is inspected)
nan_count = 0
if out2.logprobs:
    for step in out2.logprobs:
        if step is None:
            continue
        for _tid, lp_obj in step.items():
            lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
            if isinstance(lp, float) and math.isnan(lp):
                nan_count += 1
            break

if nan_count > 0:
    print("\nSample logprobs from iteration 2:")
    for idx in [0, 1, 2, len(out2.logprobs) - 1]:
        if idx < len(out2.logprobs) and out2.logprobs[idx] is not None:
            for tid, lp_obj in out2.logprobs[idx].items():
                lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
                decoded = (
                    lp_obj.decoded_token if hasattr(lp_obj, "decoded_token") else "?"
                )
                print(f"  token[{idx}] id={tid}: logprob={lp} decoded={decoded!r}")
                break

assert nan_count == 0, (
    f"FAIL: {nan_count}/{len(out2.token_ids)} logprobs are NaN on iteration 2 "
    f"(prefix caching is broken in vLLM {vllm.__version__})"
)
print(f"\n[{args.model}] ALL GOOD!")
```
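The generate, append, generate-again pattern the script exercises generalizes to any number of turns. A minimal sketch, with a hypothetical `generate` stub standing in for `llm.generate(...)`:

```python
def generate(prompt: str) -> str:
    # Stand-in for a real llm.generate(...) call; returns a fixed continuation.
    return " next-chunk"

def rollout(prompt: str, turns: int) -> str:
    # Each turn appends the model's output back into the prompt, so the next
    # request's prefix exactly matches the previous request's prompt+output.
    # This is the case prefix caching must handle without corrupting the KV cache.
    for _ in range(turns):
        prompt += generate(prompt)
    return prompt

print(rollout("count: 1 2 3", turns=3))  # "count: 1 2 3 next-chunk next-chunk next-chunk"
```

With real generation, every turn after the first hits the prefix cache; the diagnostic above only needs two turns to expose the NaN corruption.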
