
Commit ff6ba0e

Josephasafg authored and 2015aroras committed
[V1][Mamba] - Enable V1 by default for Mamba Models (vllm-project#23650)
Signed-off-by: asafg <[email protected]>
1 parent 42973fc commit ff6ba0e

File tree: 3 files changed (+72, -85 lines changed)


tests/models/language/generation/test_hybrid.py

Lines changed: 71 additions & 80 deletions
@@ -100,21 +100,19 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
 
@@ -137,7 +135,7 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
@@ -147,10 +145,6 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
@@ -188,29 +182,32 @@ def test_chunked_prefill(
     max_tokens: int,
     num_logprobs: int,
     chunked_prefill_token_size: int,
+    monkeypatch,
 ) -> None:
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -281,25 +278,29 @@ def test_models_preemption_recompute(
     example_prompts,
     model: str,
     max_tokens: int,
+    monkeypatch,
 ) -> None:
     """
     Tests that outputs are identical with and w/o preemptions (recompute).
     """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
-
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=preempt_vllm_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="vllm_preepmtions",
-            name_1="vllm",
-        )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
+
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+            check_outputs_equal(
+                outputs_0_lst=preempt_vllm_outputs,
+                outputs_1_lst=vllm_outputs,
+                name_0="vllm_preepmtions",
+                name_1="vllm",
+            )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@@ -402,24 +403,18 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
@@ -466,24 +461,20 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
-    with vllm_runner(model,
-                     max_num_seqs=MAX_NUM_SEQS,
-                     mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        m.setenv("VLLM_USE_V1", "0")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
+    with vllm_runner(model,
+                     max_num_seqs=MAX_NUM_SEQS,
+                     mamba_ssm_cache_dtype="float32") as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
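
Note on the pattern above: now that V1 is the default for Mamba models, the V0 comparison run has to be requested explicitly, while the V1 run needs no special setup. A minimal sketch of that recurring pattern, assuming the repository's vllm_runner and example_prompts pytest fixtures, a MAX_NUM_SEQS constant mirroring the one in test_hybrid.py, an assumed import path for check_logprobs_close, and an illustrative Mamba checkpoint not taken from this diff:

import pytest

from tests.models.utils import check_logprobs_close  # import path assumed

MAX_NUM_SEQS = 4  # assumption: small batch, as in the hybrid tests


@pytest.mark.parametrize("model", ["state-spaces/mamba-130m-hf"])  # illustrative checkpoint
def test_v0_matches_v1_default(vllm_runner, example_prompts, model, monkeypatch):
    # Legacy engine: must be pinned explicitly now that V1 is the Mamba default.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, 64, 5)  # max_tokens=64, num_logprobs=5

    # V1 engine: no env var and no enable_prefix_caching=False workaround needed.
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, 64, 5)

    check_logprobs_close(
        outputs_0_lst=vllm_v0_outputs,
        outputs_1_lst=vllm_v1_outputs,
        name_0="vllm_v0",
        name_1="vllm_v1",
    )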

vllm/engine/arg_utils.py

Lines changed: 0 additions & 5 deletions
@@ -1463,11 +1463,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        # V1 mamba models are unoptimized.
-        if model_config.has_inner_state and _warn_or_fallback(
-                feature_name="Mamba"):
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != SchedulerConfig.max_num_partial_prefills
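
Removing this branch is what flips the default: _is_v1_supported_oracle no longer falls back when model_config.has_inner_state is set, so Mamba models select the V1 engine unless the user overrides it. A hedged sketch of opting back into the legacy engine the same way the tests above do, with the checkpoint name chosen only for illustration:

import os

# Must be set before vLLM reads its environment configuration.
os.environ["VLLM_USE_V1"] = "0"  # opt back into the legacy (V0) engine

from vllm import LLM, SamplingParams  # public vLLM entry points

llm = LLM(model="state-spaces/mamba-130m-hf")  # illustrative Mamba checkpoint
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)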

vllm/model_executor/models/config.py

Lines changed: 1 addition & 0 deletions
@@ -417,4 +417,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
+    "FalconMambaForCausalLM": MambaModelConfig,
 }
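
This one-line addition routes FalconMamba through the same MambaModelConfig handling as the other Mamba architectures; the key in this mapping is the architecture string a checkpoint reports in its Hugging Face config. A quick way to see which key a given model resolves to, using an illustrative FalconMamba checkpoint that is not named in this diff:

from transformers import AutoConfig

# Any FalconMamba checkpoint should report the same architecture string.
cfg = AutoConfig.from_pretrained("tiiuae/falcon-mamba-7b")
print(cfg.architectures)  # expected: ['FalconMambaForCausalLM']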
