 import random
 
 import pytest
-import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode
-from vllm.distributed import cleanup_dist_env_and_memory
 
-from ...utils import fork_new_process_for_each_test
+from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
 
 # global seed
 SEED = 42
@@ -45,28 +43,12 @@ def test_prompts():
     return prompts
 
 
-def cleanup(llm: LLM, compilation_config: CompilationConfig):
-    # hacky: below lines are required to free up memory for the next test
-    # when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
-    # TODO(sarckk): when enforce_eager=False, memory is not freed:
-    # find out why and re-enable test for enforce_eager=False case
-    llm_engine = llm.llm_engine.engine_core.engine_core
-    model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
-    del model_runner.model
-    del model_runner.kv_caches
-    del compilation_config.static_forward_context
-    compilation_config.static_forward_context = {}
-
-    del llm
-    torch.cuda.empty_cache()
-    cleanup_dist_env_and_memory()
-
-
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
-@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
+@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_fast_prefill(
     monkeypatch: pytest.MonkeyPatch,
+    kv_sharing_fast_prefill: bool,
     enforce_eager: bool,
     test_prompts: list[str],
 ):
@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
         if not enforce_eager
         else CompilationMode.NONE,
     )
+    batch_size = 10
 
     with monkeypatch.context() as m:
         # Make scheduling deterministic for reproducibility
         m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
-        llm = LLM(
-            model="google/gemma-3n-E2B-it",
-            enforce_eager=enforce_eager,
-            compilation_config=compilation_config,
-            seed=SEED,
-        )
-        ref_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
+        prompts, answer, indices = prep_prompts(batch_size)
 
         llm = LLM(
             model="google/gemma-3n-E2B-it",
             enforce_eager=enforce_eager,
             compilation_config=compilation_config,
             seed=SEED,
-            kv_sharing_fast_prefill=True,
+            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
+        )
+        responses = llm.generate(prompts, sampling_params)
+        check_answers(
+            indices,
+            answer,
+            [response.outputs[0].text for response in responses],
+            accept_rate=1.0,
         )
-        optimized_responses = llm.generate(test_prompts, sampling_params)
-
-        cleanup(llm, compilation_config)
-
-        misses = 0
-
-        for ref_response, optimized_response in zip(ref_responses, optimized_responses):
-            if ref_response.outputs[0].text != optimized_response.outputs[0].text:
-                misses += 1
-
-        assert misses == 0