Commit 686623c

Fix nullable_kvs fallback (vllm-project#16837)
Signed-off-by: Harry Mellor <[email protected]>
1 parent aadb656 commit 686623c

8 files changed, +27 -19 lines changed

tests/engine/test_arg_utils.py
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 
 
 @pytest.mark.parametrize(("arg", "expected"), [
-    (None, None),
+    (None, dict()),
     ("image=16", {
         "image": 16
     }),
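
With the fallback repaired, omitting --limit-mm-per-prompt now parses to an empty dict instead of None, which is what the updated expectation asserts. A minimal usage sketch, assuming the parser setup used elsewhere in this test file:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])           # flag omitted entirely
    assert args.limit_mm_per_prompt == {}  # previously None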

tests/entrypoints/openai/test_audio.py
Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -27,7 +29,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"audio": MAXIMUM_AUDIOS}),
+        json.dumps({"audio": MAXIMUM_AUDIOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
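
The switch from str() to json.dumps() matters because the new optional_dict (see vllm/engine/arg_utils.py below) treats any brace-wrapped value strictly as JSON, and Python's str() of a dict yields single-quoted repr output that json.loads rejects. A quick illustration:

    import json

    limit = {"audio": 2}
    print(str(limit))         # {'audio': 2}  -- repr with single quotes, not valid JSON
    print(json.dumps(limit))  # {"audio": 2}  -- valid JSON

The same substitution is applied to the video, vision, vision-embedding, and Ultravox tests below.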

tests/entrypoints/openai/test_video.py
Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -31,7 +33,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"video": MAXIMUM_VIDEOS}),
+        json.dumps({"video": MAXIMUM_VIDEOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/openai/test_vision.py
Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -35,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/openai/test_vision_embedding.py
Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
+
 import pytest
 import requests
 from PIL import Image
@@ -37,7 +39,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]

tests/models/decoder_only/audio_language/test_ultravox.py
Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 from typing import Optional
 
 import numpy as np
@@ -50,7 +51,7 @@ def server(request, audio_assets):
     args = [
         "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
         "--limit-mm-per-prompt",
-        str({"audio": len(audio_assets)}), "--trust-remote-code"
+        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()

vllm/config.py
Lines changed: 3 additions & 4 deletions

@@ -10,7 +10,6 @@
 import textwrap
 import warnings
 from collections import Counter
-from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                          replace)
@@ -355,7 +354,7 @@ def __init__(
         disable_cascade_attn: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        limit_mm_per_prompt: Optional[dict[str, int]] = None,
         use_async_output_proc: bool = True,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_token: Optional[Union[bool, str]] = None,
@@ -578,7 +577,7 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
             self.tokenizer = s3_tokenizer.dir
 
     def _init_multimodal_config(
-        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
+        self, limit_mm_per_prompt: Optional[dict[str, int]]
     ) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@@ -2730,7 +2729,7 @@ def verify_with_model_config(self, model_config: ModelConfig):
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
 
-    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
+    limit_per_prompt: dict[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
     This should be a JSON string that will be parsed into a dictionary.
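
The Mapping-to-dict change is mostly a typing correction: the parsed value really is a plain dict (both json.loads and nullable_kvs return one), and annotating the field as Mapping would forbid item assignment under a type checker. A small sketch of the resulting behavior, assuming default construction via default_factory=dict:

    from vllm.config import MultiModalConfig

    cfg = MultiModalConfig()
    assert cfg.limit_per_prompt == {}  # default_factory=dict
    cfg.limit_per_prompt["image"] = 4  # fine for dict; Mapping has no __setitem__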

vllm/engine/arg_utils.py
Lines changed: 9 additions & 9 deletions

@@ -7,7 +7,7 @@
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
                     Optional, Tuple, Type, TypeVar, Union, cast, get_args,
                     get_origin)
@@ -112,14 +112,14 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
 
 
 def optional_dict(val: str) -> Optional[dict[str, int]]:
-    try:
+    if re.match("^{.*}$", val):
         return optional_arg(val, json.loads)
-    except ValueError:
-        logger.warning(
-            "Failed to parse JSON string. Attempting to parse as "
-            "comma-separated key=value pairs. This will be deprecated in a "
-            "future release.")
-        return nullable_kvs(val)
+
+    logger.warning(
+        "Failed to parse JSON string. Attempting to parse as "
+        "comma-separated key=value pairs. This will be deprecated in a "
+        "future release.")
+    return nullable_kvs(val)
 
 
 @dataclass
@@ -191,7 +191,7 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Mapping[str, int] = \
+    limit_mm_per_prompt: dict[str, int] = \
        get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
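
The core of the fix: the old code attempted json.loads and fell back to nullable_kvs on ValueError, but (per the commit title) that fallback did not actually trigger; the new code instead dispatches on input shape up front, sending brace-wrapped values to the JSON path and everything else to the deprecated key=value path. A self-contained sketch of the new control flow, with a simplified stand-in for nullable_kvs (the real helpers live in vllm/engine/arg_utils.py):

    import json
    import re

    def nullable_kvs_sketch(val: str):
        # simplified stand-in: parse "key=value[,key=value...]" into a dict
        if not val:
            return None
        return {k.strip(): int(v) for k, v in
                (pair.split("=") for pair in val.split(","))}

    def optional_dict_sketch(val: str):
        # brace-wrapped input is parsed strictly as JSON;
        # anything else takes the deprecated key=value path
        if re.match("^{.*}$", val):
            return json.loads(val)
        return nullable_kvs_sketch(val)

    assert optional_dict_sketch('{"image": 16}') == {"image": 16}
    assert optional_dict_sketch("image=16") == {"image": 16}

One consequence of the regex dispatch: a malformed brace-wrapped string now raises a JSON error immediately instead of being retried as key=value pairs, which is why the tests above must pass json.dumps output rather than str(dict) output.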
