Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit a3fce56

Browse files
[Speculative Decoding] EAGLE Implementation with Top-1 proposer (vllm-project#6830)
1 parent b3856be commit a3fce56

File tree

17 files changed

+854
-83
lines changed

17 files changed

+854
-83
lines changed
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.

With those tests, we can say at least, EAGLE would not break the
correctness for the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model: the target model whose outputs must be preserved
MAIN_MODEL = "JackFram/llama-68m"

# speculative model: the EAGLE draft head used to propose tokens
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision used for both baseline and test runs
PRECISION = "float32"
38+
39+
40+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph capture to keep the test fast
        "use_v2_block_manager": True,  # required for spec decode
        "disable_log_stats": False,  # print spec metrics
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """EAGLE spec decode must produce greedy output identical to the
    baseline, for both a single sequence and a larger batch."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
80+
81+
82+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,  # exercise the cuda graph path
        "use_v2_block_manager": True,  # required for spec decode
        "disable_log_stats": False,  # print spec metrics
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
                                                 test_llm_generator,
                                                 batch_size: int,
                                                 output_len: int):
    """Greedy equality must hold with cuda graph recording enabled
    (enforce_eager=False), across batch sizes."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
123+
124+
125+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # Tight KV-cache budget: 2 blocks for the small prompt plus 256//8
        # for generated tokens — forces preemption mid-generation.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,
        "enforce_eager": True,  # skip cuda graph capture to keep the test fast
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
# Small output len keeps the test fast.
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Greedy equality must survive sequences being preempted
    mid-generation by the constrained KV-cache budget above."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
173+
174+
175+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph capture to keep the test fast
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": k,
    } for k in range(1, 1 + MAX_SPEC_TOKENS)])  # sweep every supported k
@pytest.mark.parametrize("batch_size", [2])
# Smaller output len keeps the k-sweep fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """EAGLE spec decode must match the baseline exactly for every
    value of num_speculative_tokens from 1 through MAX_SPEC_TOKENS."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
220+
221+
222+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": True,  # skip cuda graph capture to keep the test fast
        "use_v2_block_manager": True,  # required for spec decode
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "num_speculative_tokens": MAX_SPEC_TOKENS,
    # Speculation is auto-disabled once the batch reaches this size;
    # batch_size=5 below exercises that path, batch_size=1 does not.
    "speculative_disable_by_batch_size": 4,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
# Smaller output len keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
                             batch_size: int, output_len: int):
    """Greedy equality must hold whether or not speculation gets
    disabled by the batch-size threshold."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
264+
265+
266+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    # NOTE: pytest is already imported at module scope, so the original
    # redundant local `import pytest` has been removed.
    pytest.main([__file__])

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@
7070
])
7171
@pytest.mark.parametrize("batch_size", [1, 32])
7272
@pytest.mark.parametrize("seed", [1])
73-
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
74-
batch_size: int, output_len: int):
73+
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
74+
test_llm_generator, batch_size: int,
75+
output_len: int):
7576
"""Verify greedy equality with different batch size."""
7677
run_greedy_equality_correctness_test(baseline_llm_generator,
7778
test_llm_generator,
@@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
8081
force_output_len=True)
8182

8283

84+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,  # exercise the cuda graph path
        "use_v2_block_manager": True,  # required for spec decode
        "disable_log_stats": False,  # print spec metrics
        "dtype": PRECISION,
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": SPEC_MODEL,
    "num_speculative_tokens": MAX_SPEC_TOKENS,
}])
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
                                                  test_llm_generator,
                                                  batch_size: int,
                                                  output_len: int):
    """Medusa spec decode must match greedy baseline output with cuda
    graph recording enabled (enforce_eager=False), across batch sizes."""
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
125+
126+
83127
@pytest.mark.parametrize(
84128
"common_llm_kwargs",
85129
[{
@@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
116160
])
117161
@pytest.mark.parametrize("batch_size", [4])
118162
@pytest.mark.parametrize("seed", [1])
119-
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
120-
test_llm_generator,
121-
batch_size: int,
122-
output_len: int):
163+
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
164+
test_llm_generator,
165+
batch_size: int,
166+
output_len: int):
123167
"""Verify greedy equality, even when some sequences are preempted mid-
124168
generation.
125169
"""
@@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
165209
32,
166210
])
167211
@pytest.mark.parametrize("seed", [1])
168-
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
169-
batch_size: int, output_len: int):
170-
"""Verify that mlp speculative decoding produces exact equality
212+
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
213+
batch_size: int, output_len: int):
214+
"""Verify that medusa speculative decoding produces exact equality
171215
to without spec decode with different values of num_speculative_tokens.
172216
"""
173217
run_greedy_equality_correctness_test(baseline_llm_generator,
@@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
208252
32,
209253
])
210254
@pytest.mark.parametrize("seed", [1])
211-
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
212-
batch_size: int, output_len: int):
213-
"""Verify that mlp speculative decoding produces exact equality
255+
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
256+
batch_size: int, output_len: int):
257+
"""Verify that medusa speculative decoding produces exact equality
214258
to without spec decode when speculation is disabled for large
215259
batch sizes.
216260
"""

0 commit comments

Comments
 (0)