
Commit e6a26ed
[SpecDecode][Kernel] Flashinfer Rejection Sampling (vllm-project#7244)
1 parent f8d6014

9 files changed: +306 -109 lines

Dockerfile
Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist

 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################

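This hunk only bumps the pinned FlashInfer wheel from v0.1.4 to v0.1.6. As an illustrative aside (not part of the commit), the pin can be sanity-checked inside the built image by querying the installed distribution metadata; the distribution name "flashinfer" is assumed from the wheel filename:

# Illustrative check only, not part of this commit: confirm the installed
# FlashInfer version matches the wheel pinned in the Dockerfile above.
from importlib.metadata import PackageNotFoundError, version

EXPECTED_PREFIX = "0.1.6"  # version pinned by the updated Dockerfile line

try:
    installed = version("flashinfer")  # distribution name assumed from the wheel
except PackageNotFoundError as exc:
    raise SystemExit("flashinfer is not installed in this image") from exc

assert installed.startswith(EXPECTED_PREFIX), (
    f"expected flashinfer {EXPECTED_PREFIX}*, found {installed}")
print(f"flashinfer {installed} OK")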

tests/samplers/test_rejection_sampler.py
Lines changed: 97 additions & 19 deletions

@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
 @pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str,
-                               disable_bonus_tokens: bool, seed: int,
-                               device: str):
+def test_correct_output_format(which_tokens_accepted: str, seed: int,
+                               disable_bonus_tokens: bool, device: str,
+                               use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
+    if use_flashinfer and disable_bonus_tokens:
+        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
+
     set_random_seed(seed)
     torch.set_default_device(device)

@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
                                     dtype=torch.int64)

     rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens)
+        disable_bonus_tokens=disable_bonus_tokens,
+        use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
-                                    device: str):
+                                    device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
-                                   frac_seeded: float, n_rep: int,
-                                   device: str):
+                                   frac_seeded: float, n_rep: int, device: str,
+                                   use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])


+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
+                                       batch_size: int, device: str):
+    """
+    Test the flashinfer and nonflashinfer backend generate
+    the same output metrics.
+    """
+    torch.set_default_device(device)
+    torch.manual_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    num_accepted_tokens = []
+    num_emitted_tokens = []
+    num_draft_tokens = []
+
+    def get_seeded_seqs():
+        return {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size)
+        }
+
+    for use_flashinfer in [True, False]:
+        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                             use_flashinfer=use_flashinfer)
+        rejection_sampler.init_gpu_tensors(device=device)
+        # We use seeded sequences to ensure the same tokens are accepted
+        # for both flashinfer and nonflashinfer backends.
+        seeded_seqs = get_seeded_seqs()
+        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                          draft_token_ids, seeded_seqs)
+        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
+        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
+        num_draft_tokens.append(rejection_sampler.num_draft_tokens)
+
+    assert num_accepted_tokens[0] == num_accepted_tokens[1]
+    assert num_emitted_tokens[0] == num_emitted_tokens[1]
+    assert num_draft_tokens[0] == num_draft_tokens[1]
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
 @pytest.mark.parametrize("which_token_ids",
                          ["bonus_token_ids", "draft_token_ids"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
-                               which_token_ids: str, device: str):
+                               which_token_ids: str, device: str,
+                               use_flashinfer: bool):
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)

-    rejection_sampler = RejectionSampler(strict_mode=True)
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer,
+                                         strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,

 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
 @pytest.mark.parametrize("seed", list(range(5)))
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_rejection_sampling_approximates_target_distribution(
-        seed: int, draft_and_target_probs_equal: bool):
+        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
     """Verify rejection sampling approximates target distribution,
     despite sampling from a potentially distinct draft distribution.

@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
     """
     torch.set_default_device("cpu")
     set_random_seed(seed)
-
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(),
+        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
+                                           use_flashinfer=use_flashinfer),
     )

     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
@@ -398,10 +476,10 @@ def _estimate_rejection_sampling_pdf(
         draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
             num_samples, 1, 1)

-        # Repeat target probs num_samples * k times.
+        # Repeat target probs num_samples * (k + 1) times.
         # Rejection sampler requires bonus token probs, but they aren't used.
         target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
-            num_samples, self.k, 1)
+            num_samples, self.k + 1, 1)

         # Randomly sample draft token ids from draft probs.
         draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
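The common thread in these test changes is that the rejection sampler now takes target probabilities with one extra bonus position: draft_probs keeps shape (batch_size, k, vocab_size) while target_probs becomes (batch_size, k + 1, vocab_size), and the constructor gains disable_bonus_tokens and use_flashinfer arguments. A minimal sketch of the call pattern the tests exercise, with the shapes spelled out; the import path is assumed from the vLLM source tree at this commit, the sizes are illustrative, and running it needs a CUDA device (plus FlashInfer when use_flashinfer=True):

import torch

# Import path assumed from the vLLM source tree at this commit.
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

batch_size, k, vocab_size = 4, 3, 32_000
device = "cuda:0"

sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=True)
sampler.init_gpu_tensors(device=device)

# Draft distribution covers the k speculated positions; the target distribution
# carries one extra "bonus" position, hence k + 1.
draft_probs = torch.rand(batch_size, k, vocab_size, device=device)
target_probs = torch.rand(batch_size, k + 1, vocab_size, device=device)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), device=device)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), device=device)

# Optional per-sequence generators, mirroring test_compare_nonflashinfer_backend.
seeded_seqs = {
    i: torch.Generator(device=device).manual_seed(i)
    for i in range(batch_size)
}

output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                           draft_token_ids, seeded_seqs)
print(output_token_ids.shape)  # expected: (batch_size, k + 1)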

tests/samplers/test_typical_acceptance_sampler.py
Lines changed: 32 additions & 18 deletions

@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs,
+    typical_acceptance_sampler(target_with_bonus_probs,
                                bonus_token_ids,
                                draft_probs=None,
                                draft_token_ids=draft_token_ids)
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id

     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs,
+        typical_acceptance_sampler(target_with_bonus_probs,
                                    bonus_token_ids,
                                    draft_probs=None,
                                    draft_token_ids=draft_token_ids)
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_with_bonus_probs = torch.rand(batch_size,
+                                         k + 1,
+                                         vocab_size,
+                                         dtype=torch.float32)
     draft_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, k),
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
     # Simulate temperature 0 probability distribution for target probabilities
     # and create target probabilities such that only 1 token id has
     # probability 1.0
-    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size)
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     # Populate draft_token_ids such that they exclude the token_ids
     # with probability = 1.0
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     # For sequences 0 and 2 set the distribution to a temperature
     # zero distribution. For sequences 1 and 3 set it to a uniform
     # distribution.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
+    target_probs = target_with_bonus_probs[:, :-1]
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                           zero_temperature_token_ids)
     uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     # Create a temperature zero target probability distribution and ensure
     # all draft token ids correspond to the tokens with 1.0 probability.
     # Verify that all of them are accepted.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_with_bonus_probs, zero_temperature_token_ids = \
+        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     draft_token_ids = zero_temperature_token_ids
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
     output_token_ids = typical_acceptance_sampler(
-        target_probs,
+        target_with_bonus_probs,
         bonus_token_ids,
         draft_probs=None,
         draft_token_ids=draft_token_ids)
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     # 0.00001. Populate draft_token_ids such that they exclude the token_ids
     # with probability = 1.0. Without any changes to the posterior thresholds
     # none of the draft tokens are accepted.
-    target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
-        batch_size, k, vocab_size))
+    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
+        batch_size, k + 1, vocab_size)
+    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
     target_probs[target_probs == 0] = 0.00001
     draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                           zero_temperature_token_ids)
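These tests switch to a target_with_bonus_probs tensor with k + 1 positions and trim the bonus column from the helper outputs before comparing against the k draft tokens. A small self-contained sketch of the shape bookkeeping involved, using plain tensors only (the test helpers themselves are not reproduced here):

import torch

batch_size, k, vocab_size = 2, 3, 10

# The target distribution now covers the k speculated positions plus one
# bonus position.
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size)

# Dropping the last (bonus) position recovers the k-position target tensor.
target_probs = target_with_bonus_probs[:, :-1]
assert target_probs.shape == (batch_size, k, vocab_size)

# The same trim is applied to token ids produced for k + 1 slots, so they can
# be compared against the k draft tokens.
token_ids = torch.randint(0, vocab_size, (batch_size, k + 1))
assert token_ids[:, :-1].shape == (batch_size, k)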

tests/spec_decode/test_spec_decode_worker.py
Lines changed: 2 additions & 3 deletions

@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,

     assert torch.equal(actual.bonus_token_ids,
                        target_token_ids.reshape(batch_size, k + 1)[:, -1:])
-    assert torch.equal(
-        actual.target_probs,
-        target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
+    assert torch.equal(actual.target_with_bonus_probs,
+                       target_token_probs.reshape(batch_size, k + 1, -1))
     assert torch.equal(actual.draft_token_ids, proposal_token_ids)
     assert torch.equal(actual.draft_probs, proposal_probs)

vllm/envs.py
Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
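The envs.py hunk only adds the type stub for the new flag. As a generic illustration of how a boolean environment flag like VLLM_USE_FLASHINFER_REJECTION_SAMPLER is typically read (this is a sketch, not the exact parsing code vLLM uses):

import os

def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret common truthy strings ("1", "true", "yes") as True."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes")

# Illustrative read of the flag added in this commit.
VLLM_USE_FLASHINFER_REJECTION_SAMPLER = _env_flag(
    "VLLM_USE_FLASHINFER_REJECTION_SAMPLER")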
