
Commit adb5a27

Josephasafg authored and nopperl committed
[V1][Mamba] - Enable V1 by default for Mamba Models (vllm-project#23650)
Signed-off-by: asafg <[email protected]>
1 parent fbd07c5 commit adb5a27
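
With this commit, Mamba and Mamba-style hybrid models select the V1 engine by default instead of falling back to V0. A minimal usage sketch under that assumption; the checkpoint name is only illustrative, and the commented-out VLLM_USE_V1=0 override (the same variable the updated tests pin) is how the legacy V0 engine can still be forced:

# Optional: set VLLM_USE_V1=0 in the environment before constructing the
# engine to force the legacy V0 path; leaving it unset uses the V1 default.
from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf")  # illustrative Mamba checkpoint
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)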

File tree: 3 files changed, +72 -85 lines changed


tests/models/language/generation/test_hybrid.py

Lines changed: 71 additions & 80 deletions
@@ -99,21 +99,19 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
 
@@ -136,7 +134,7 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
@@ -146,10 +144,6 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
@@ -187,29 +181,32 @@ def test_chunked_prefill(
     max_tokens: int,
     num_logprobs: int,
     chunked_prefill_token_size: int,
+    monkeypatch,
 ) -> None:
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -280,25 +277,29 @@ def test_models_preemption_recompute(
     example_prompts,
     model: str,
     max_tokens: int,
+    monkeypatch,
 ) -> None:
     """
     Tests that outputs are identical with and w/o preemptions (recompute).
     """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
-
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=preempt_vllm_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="vllm_preepmtions",
-            name_1="vllm",
-        )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+            check_outputs_equal(
+                outputs_0_lst=preempt_vllm_outputs,
+                outputs_1_lst=vllm_outputs,
+                name_0="vllm_preepmtions",
+                name_1="vllm",
+            )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -401,24 +402,18 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
@@ -465,24 +460,20 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
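
Note how the changes above invert the previous arrangement: the V1 runs no longer need an explicit m.setenv("VLLM_USE_V1", "1") (nor the enable_prefix_caching=False and FLASHINFER workarounds), while the V0 reference runs are now the ones pinned through pytest's monkeypatch fixture. A minimal, self-contained sketch of that pinning pattern; the run_engine helper is an illustrative stand-in for the vllm_runner fixture, not part of this test file:

import os

import pytest


def run_engine() -> str:
    # Stand-in for vllm_runner: report which engine the env var would select.
    # Unset (or "1") means the now-default V1 engine.
    return "V0" if os.environ.get("VLLM_USE_V1") == "0" else "V1"


def test_v0_reference_then_v1_default(monkeypatch: pytest.MonkeyPatch) -> None:
    # Pin the legacy engine only for the reference run ...
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        assert run_engine() == "V0"
    # ... the variable is restored on exit, so this run uses the V1 default.
    assert run_engine() == "V1"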

vllm/engine/arg_utils.py

Lines changed: 0 additions & 5 deletions
@@ -1463,11 +1463,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                              recommend_to_remove=False)
             return False
 
-        # V1 mamba models are unoptimized.
-        if model_config.has_inner_state and _warn_or_fallback(
-                feature_name="Mamba"):
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != SchedulerConfig.max_num_partial_prefills
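
With this guard removed, _is_v1_supported_oracle no longer treats models with recurrent inner state (Mamba-style SSMs) as a reason to fall back to the V0 engine. A simplified, hypothetical sketch of the feature gate that was deleted; only has_inner_state, the "Mamba" feature name, and the warn-or-fallback idea come from the hunk above, the rest is illustrative:

import logging

logger = logging.getLogger(__name__)


def warn_or_fallback(feature_name: str) -> bool:
    # Stand-in for the helper referenced above: log a warning and signal
    # that the caller should fall back to the V0 engine.
    logger.warning("%s is not optimized on V1; falling back to V0.", feature_name)
    return True


def is_v1_supported(has_inner_state: bool) -> bool:
    # The deleted check pushed every model with Mamba-style inner state to V0;
    # without it, such models continue through the remaining V1 feature checks.
    if has_inner_state and warn_or_fallback("Mamba"):
        return False
    return True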

vllm/model_executor/models/config.py

Lines changed: 1 addition & 0 deletions
@@ -417,4 +417,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
+    "FalconMambaForCausalLM": MambaModelConfig,
 }
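
The one-line addition registers FalconMamba under the same config-override class that Mamba and Mamba2 already use, so it receives the same verification and updates at engine start-up. A minimal, self-contained sketch of this architecture-to-override dispatch; the map name, the helper, and the body of verify_and_update_config are illustrative stand-ins, only the three architecture strings and the MambaModelConfig name come from the hunk above:

from typing import Any, Dict


class MambaModelConfig:
    @classmethod
    def verify_and_update_config(cls, vllm_config: Dict[str, Any]) -> None:
        # Placeholder for the Mamba-specific config fix-ups applied at start-up.
        vllm_config["mamba_config_verified"] = True


# Architecture name -> config-override class, mirroring the mapping above.
MODEL_CONFIG_OVERRIDES = {
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
    "FalconMambaForCausalLM": MambaModelConfig,  # newly registered by this commit
}


def apply_config_overrides(architecture: str, vllm_config: Dict[str, Any]) -> None:
    override = MODEL_CONFIG_OVERRIDES.get(architecture)
    if override is not None:
        override.verify_and_update_config(vllm_config)


# Example: FalconMamba now takes the same path as the other Mamba variants.
cfg: Dict[str, Any] = {}
apply_config_overrides("FalconMambaForCausalLM", cfg)
assert cfg["mamba_config_verified"]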
