@@ -100,21 +100,19 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
 
@@ -137,7 +135,7 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
@@ -147,10 +145,6 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
@@ -188,29 +182,32 @@ def test_chunked_prefill(
     max_tokens: int,
     num_logprobs: int,
     chunked_prefill_token_size: int,
+    monkeypatch,
 ) -> None:
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -281,25 +278,29 @@ def test_models_preemption_recompute(
     example_prompts,
     model: str,
     max_tokens: int,
+    monkeypatch,
 ) -> None:
     """
     Tests that outputs are identical with and w/o preemptions (recompute).
     """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
-
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=preempt_vllm_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="vllm_preepmtions",
-            name_1="vllm",
-        )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+            check_outputs_equal(
+                outputs_0_lst=preempt_vllm_outputs,
+                outputs_1_lst=vllm_outputs,
+                name_0="vllm_preepmtions",
+                name_1="vllm",
+            )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -402,24 +403,18 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
@@ -466,24 +461,20 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
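Most hunks above apply the same pattern: the V0 run is wrapped in pytest's `monkeypatch.context()` with `VLLM_USE_V1=0`, while the V1 run sits outside that block and picks up the default engine. Below is a minimal, self-contained sketch of that environment-pinning pattern; the test name and assertions are illustrative only and are not part of this change.

```python
import os


def test_use_v1_env_pinning(monkeypatch):
    # Start from a clean slate so the assertions below do not depend on the
    # caller's shell environment.
    monkeypatch.delenv("VLLM_USE_V1", raising=False)

    with monkeypatch.context() as m:
        # Pin the V0 engine only for runners constructed inside this block,
        # mirroring the hunks above.
        m.setenv("VLLM_USE_V1", "0")
        assert os.environ["VLLM_USE_V1"] == "0"

    # The override is rolled back when the context exits, so anything created
    # afterwards sees whatever engine vLLM selects by default.
    assert "VLLM_USE_V1" not in os.environ
```

Because the override is scoped to the `with` block, the V0 and V1 runners created in the same test function cannot leak engine settings into each other.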