@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
 @pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str,
-                               disable_bonus_tokens: bool, seed: int,
-                               device: str):
+def test_correct_output_format(which_tokens_accepted: str, seed: int,
+                               disable_bonus_tokens: bool, device: str,
+                               use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
+    if use_flashinfer and disable_bonus_tokens:
+        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
+
     set_random_seed(seed)
     torch.set_default_device(device)

@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
                                     dtype=torch.int64)

     rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens)
+        disable_bonus_tokens=disable_bonus_tokens,
+        use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
-                                    device: str):
+                                    device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
-                                   frac_seeded: float, n_rep: int,
-                                   device: str):
+                                   frac_seeded: float, n_rep: int, device: str,
+                                   use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                 assert torch.equal(results[j][i], results[0][i])


+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
+                                       batch_size: int, device: str):
+    """
+    Test that the flashinfer and nonflashinfer backends generate
+    the same output metrics.
+    """
+    torch.set_default_device(device)
+    torch.manual_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    num_accepted_tokens = []
+    num_emitted_tokens = []
+    num_draft_tokens = []
+
+    def get_seeded_seqs():
+        return {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size)
+        }
+
+    for use_flashinfer in [True, False]:
+        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                             use_flashinfer=use_flashinfer)
+        rejection_sampler.init_gpu_tensors(device=device)
+        # We use seeded sequences to ensure the same tokens are accepted
+        # for both flashinfer and nonflashinfer backends.
+        seeded_seqs = get_seeded_seqs()
+        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                          draft_token_ids, seeded_seqs)
+        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
+        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
+        num_draft_tokens.append(rejection_sampler.num_draft_tokens)
+
+    assert num_accepted_tokens[0] == num_accepted_tokens[1]
+    assert num_emitted_tokens[0] == num_emitted_tokens[1]
+    assert num_draft_tokens[0] == num_draft_tokens[1]
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
 @pytest.mark.parametrize("which_token_ids",
                          ["bonus_token_ids", "draft_token_ids"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
-                               which_token_ids: str, device: str):
+                               which_token_ids: str, device: str,
+                               use_flashinfer: bool):
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)

-    rejection_sampler = RejectionSampler(strict_mode=True)
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer,
+                                         strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,

 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
 @pytest.mark.parametrize("seed", list(range(5)))
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_rejection_sampling_approximates_target_distribution(
-        seed: int, draft_and_target_probs_equal: bool):
+        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
     """Verify rejection sampling approximates target distribution,
     despite sampling from a potentially distinct draft distribution.

@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
     """
     torch.set_default_device("cpu")
     set_random_seed(seed)
-
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(),
+        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
+                                           use_flashinfer=use_flashinfer),
     )

     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
@@ -398,10 +476,10 @@ def _estimate_rejection_sampling_pdf(
         draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
             num_samples, 1, 1)

-        # Repeat target probs num_samples * k times.
+        # Repeat target probs num_samples * (k + 1) times.
         # Rejection sampler requires bonus token probs, but they aren't used.
         target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
-            num_samples, self.k, 1)
+            num_samples, self.k + 1, 1)

         # Randomly sample draft token ids from draft probs.
         draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
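For context, here is a minimal standalone sketch (not part of the diff) of how the new constructor arguments and the seeded-generator call pattern from test_compare_nonflashinfer_backend fit together. The import path and the output shape comment are assumptions about the surrounding vLLM code; everything else mirrors the calls shown in the diff above.

# Hypothetical usage sketch; import path assumed, flashinfer assumed installed
# when use_flashinfer=True.
import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler

k, vocab_size, batch_size, device = 3, 30_000, 8, "cuda:0"
torch.set_default_device(device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
# target_probs now carries k + 1 positions: k draft positions plus the bonus token.
target_probs = torch.rand(batch_size, k + 1, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

# Per-sequence seeded generators make the accept/reject decisions reproducible,
# which is what lets the new test compare the two backends metric-for-metric.
seeded_seqs = {
    i: torch.Generator(device=device).manual_seed(i)
    for i in range(batch_size)
}

sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=True)
sampler.init_gpu_tensors(device=device)
output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                           draft_token_ids, seeded_seqs)
print(output_token_ids.shape)  # should be (batch_size, k + 1)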