@@ -99,21 +99,19 @@ def test_models(
    else:
        hf_outputs = None

-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None

    if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
    else:
        vllm_v1_outputs = None

@@ -136,7 +134,7 @@ def test_models(
    )


-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
@@ -146,10 +144,6 @@ def test_batching(
    max_tokens: int,
    num_logprobs: int,
) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
@@ -187,29 +181,32 @@ def test_chunked_prefill(
    max_tokens: int,
    num_logprobs: int,
    chunked_prefill_token_size: int,
+    monkeypatch,
) -> None:
    max_num_seqs = chunked_prefill_token_size
    max_num_batched_tokens = chunked_prefill_token_size

-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)

-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)

-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -280,25 +277,29 @@ def test_models_preemption_recompute(
    example_prompts,
    model: str,
    max_tokens: int,
+    monkeypatch,
) -> None:
    """
    Tests that outputs are identical with and w/o preemptions (recompute).
    """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
-
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=preempt_vllm_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="vllm_preepmtions",
-            name_1="vllm",
-        )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+            check_outputs_equal(
+                outputs_0_lst=preempt_vllm_outputs,
+                outputs_1_lst=vllm_outputs,
+                name_0="vllm_preepmtions",
+                name_1="vllm",
+            )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -401,24 +402,18 @@ def test_full_cuda_graph(
    else:
        hf_outputs = None

-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)

    if hf_outputs is not None and vllm_v0_outputs is not None:
        check_logprobs_close(
@@ -465,24 +460,20 @@ def test_fp32_state(
    else:
        hf_outputs = None

-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
        with vllm_runner(model,
                         max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
    if hf_outputs is not None:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
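The recurring change in this diff is to stop setting VLLM_USE_V1=1 around the V1 runs and instead pin VLLM_USE_V1=0 around the V0 baseline runs, which implies the V1 engine is now selected by default. The sketch below is a minimal, self-contained illustration of that monkeypatch.context() pattern under that assumption; engine_version() is a toy stand-in written for this example only, not a vLLM API.

# Standalone sketch of the env-pinning pattern used in the tests above.
# Assumption: V1 is the default engine unless VLLM_USE_V1 == "0".
# `engine_version` is a hypothetical helper for illustration only.
import os

import pytest


def engine_version() -> str:
    return "v0" if os.environ.get("VLLM_USE_V1") == "0" else "v1"


def test_v0_pin_is_scoped(monkeypatch: pytest.MonkeyPatch) -> None:
    before = os.environ.get("VLLM_USE_V1")
    with monkeypatch.context() as m:
        # Pin the V0 engine only for the baseline run.
        m.setenv("VLLM_USE_V1", "0")
        assert engine_version() == "v0"
    # monkeypatch.context() undoes setenv on exit, so a later run in the
    # same test (the V1 comparison) sees the original, default environment.
    assert os.environ.get("VLLM_USE_V1") == before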