@@ -343,3 +343,78 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                 b=baseline_rank_to_logprob[rank],
                 abs_tol=1e-1,
             )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "max_logprobs": 6,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 3,
+                             "disable_logprobs_during_spec_decoding": True,
+                         }])
+@pytest.mark.parametrize("seed", [1])
+def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
+    """Check the behavior when logprobs are disabled.
+    Token choices should match with the base model.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
+
+    sampling_params = SamplingParams(
+        # Use smaller output len for fast test
+        max_tokens=7,
+        ignore_eos=True,
+        temperature=0.0,
+        logprobs=2,
+    )
+
+    spec_batch_logprobs = get_logprobs_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+    baseline_batch_logprobs = get_logprobs_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_logprobs) == len(prompts)
+    assert len(spec_batch_logprobs) == len(prompts)
+
+    # For each sequence in the batch.
+    for _, (baseline_logprobs, spec_logprobs) in enumerate(
+            zip(baseline_batch_logprobs, spec_batch_logprobs)):
+        assert len(spec_logprobs) == len(baseline_logprobs)
+
+        # For each generated position of the sequence.
+        for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
+                zip(spec_logprobs, baseline_logprobs)):
+
+            assert len(spec_pos_logprobs) == 1
+            spec_top_token_id = list(spec_pos_logprobs)[0]
+
+            spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
+            assert spec_top_logprob.logprob == 0.0
+            assert spec_top_logprob.rank == -1
+
+            # check that the chosen token matches the base model
+            baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
+            assert baseline_logprob.rank == 1
+            assert spec_top_logprob.decoded_token \
+                == baseline_logprob.decoded_token
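
For reference, a minimal sketch (not vLLM's actual implementation) of the placeholder logprobs the new test expects when disable_logprobs_during_spec_decoding is enabled: each generated position carries a single entry for the sampled token, with logprob 0.0 and rank -1 instead of real values. The Logprob dataclass and make_placeholder_logprobs helper below are illustrative assumptions; only the field names (logprob, rank, decoded_token) are taken from the assertions in the diff above.

# Illustrative sketch only -- Logprob and make_placeholder_logprobs are
# hypothetical stand-ins, not vLLM's own classes or helpers.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Logprob:
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


def make_placeholder_logprobs(sampled_token_id: int,
                              decoded_token: str) -> Dict[int, Logprob]:
    # Only the sampled token is returned, with dummy values, so no real
    # logprob computation is needed when spec-decode logprobs are disabled.
    return {
        sampled_token_id: Logprob(logprob=0.0, rank=-1,
                                  decoded_token=decoded_token)
    }


# The dummy structure satisfies the same per-position assertions as the test.
pos_logprobs = make_placeholder_logprobs(sampled_token_id=42,
                                         decoded_token=" Paris")
assert len(pos_logprobs) == 1
top_token_id = list(pos_logprobs)[0]
assert pos_logprobs[top_token_id].logprob == 0.0
assert pos_logprobs[top_token_id].rank == -1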