
Commit cc0eaf1

[Bugfix] spec decode handle None entries in topk args in create_sequence_group_output (vllm-project#7232)
Signed-off-by: Travis Johnson <[email protected]>
1 parent 955b519 commit cc0eaf1

2 files changed: +84, -7 lines changed


tests/spec_decode/e2e/test_logprobs.py

Lines changed: 75 additions & 0 deletions
@@ -343,3 +343,78 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                 b=baseline_rank_to_logprob[rank],
                 abs_tol=1e-1,
             )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match with the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+
+    # For each sequence in the batch.
+    for _, (baseline_logprobs, spec_logprobs) in enumerate(
+            zip(baseline_batch_logprobs, spec_batch_logprobs)):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+
+        # For each generated position of the sequence.
+        for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+                zip(spec_logprobs, baseline_logprobs)):
+
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+
+            # check that the chosen token matches the base model
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert spec_top_logprob.decoded_token \
+                    == baseline_logprob.decoded_token
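
The assertions above walk a nested structure: get_logprobs_from_llm_generator returns one list per prompt, each containing one dict per generated position, mapping a token id to a Logprob that carries logprob, rank, and decoded_token. Below is a minimal sketch of that traversal, using a hypothetical FakeLogprob stand-in (not vLLM's own class) and made-up token ids:

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeLogprob:
    # Illustrative stand-in for vLLM's Logprob; only the fields the test
    # touches are modeled here.
    logprob: float
    rank: int
    decoded_token: str


def greedy_token_ids(
        batch_logprobs: List[List[Dict[int, FakeLogprob]]]) -> List[List[int]]:
    # One list per prompt, one dict per generated position; with logprobs
    # disabled each dict holds a single entry keyed by the sampled token id.
    return [[next(iter(pos_logprobs)) for pos_logprobs in seq_logprobs]
            for seq_logprobs in batch_logprobs]


# One sequence, two positions, placeholder entries as asserted in the test
# above (logprob 0.0, rank -1).
batch = [[{42: FakeLogprob(0.0, -1, " Paris")},
          {7: FakeLogprob(0.0, -1, ",")}]]
assert greedy_token_ids(batch) == [[42, 7]]

With disable_logprobs_during_spec_decoding=True, each position holds exactly one such placeholder entry, which is why the test can read the chosen token as the dict's only key.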

vllm/spec_decode/util.py

Lines changed: 9 additions & 7 deletions
@@ -64,23 +64,25 @@ def create_sequence_group_output(
         token_id_logprob_rank (int): The logprob rank of the sampled token.
         token_id_logprob (float): The logprob value of the sampled token.
         seq_id (int): The sequence id.
-        topk_token_ids (List[int]): The list of top-k token ids.
-        topk_logprobs (List[float]): The list of top-k logprobs.
+        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
+        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
     """
     # vLLM logprobs always include the sampled token. In addition, the user may
     # request topk-logprobs (where top-k varies per user up to max_logprobs).
-    logprobs: Dict[Optional[int], Logprob] = {
+    logprobs: Dict[int, Logprob] = {
         token_id: Logprob(
             logprob=token_id_logprob,
             rank=token_id_logprob_rank,
         ),
     }
     logprobs.update({
-        topk_token_ids[topk_logprob_index]: Logprob(
-            logprob=topk_logprobs[topk_logprob_index],
-            rank=topk_logprob_index + 1,
+        topk_token_id: Logprob(
+            logprob=topk_logprob if topk_logprob is not None else 0.0,
+            rank=topk_index + 1,
         )
-        for topk_logprob_index, _ in enumerate(topk_token_ids)
+        for topk_index, (topk_token_id, topk_logprob) \
+            in enumerate(zip(topk_token_ids, topk_logprobs)) \
+            if topk_token_id is not None
     })

     return CompletionSequenceGroupOutput(
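
The heart of the fix is the rewritten dict comprehension: the flattened top-k arguments may contain None placeholders (for example when logprobs are disabled during speculative decoding, as exercised by the new test), and those entries are now skipped instead of becoming None keys, which also lets the annotation tighten from Dict[Optional[int], Logprob] to Dict[int, Logprob]. A minimal standalone sketch of the same filtering, assuming plain (logprob, rank) tuples in place of vLLM's Logprob objects and illustrative inputs:

from typing import Dict, List, Optional, Tuple


def build_topk_logprobs(
        topk_token_ids: List[Optional[int]],
        topk_logprobs: List[Optional[float]]) -> Dict[int, Tuple[float, int]]:
    # Skip None token ids entirely; a None logprob paired with a real token id
    # falls back to 0.0, mirroring the comprehension in the diff above.
    return {
        topk_token_id: (topk_logprob if topk_logprob is not None else 0.0,
                        topk_index + 1)
        for topk_index, (topk_token_id, topk_logprob) in enumerate(
            zip(topk_token_ids, topk_logprobs))
        if topk_token_id is not None
    }


# Illustrative inputs: the old positional indexing would have emitted an entry
# keyed by None here; the filtered version simply drops it.
assert build_topk_logprobs([11, None, 13], [-0.1, None, None]) == {
    11: (-0.1, 1),
    13: (0.0, 3),
}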
