Skip to content

Commit 59b453e

Browse files
authored
Speed up mm processor kwargs per request by splitting dynamic and static kwargs (vllm-project#26483)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com>
1 parent 827e423 commit 59b453e

File tree

2 files changed

+155
-3
lines changed

2 files changed

+155
-3
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import importlib
5+
6+
from transformers.processing_utils import ProcessingKwargs
7+
from typing_extensions import Unpack
8+
9+
from vllm.transformers_utils.processor import (
10+
get_processor_kwargs_from_processor,
11+
)
12+
13+
14+
class _FakeProcessorKwargs(ProcessingKwargs, total=False):  # type: ignore
    """Empty ``ProcessingKwargs`` subclass used as the ``Unpack`` target in tests."""
16+
17+
18+
def _assert_has_all_expected(keys: set[str]) -> None:
    """Assert that ``keys`` contains a sample of kwargs from every modality.

    Args:
        keys: Collected processor kwarg names to check.

    Raises:
        AssertionError: If any expected key is absent; the message names the
            modality group and the missing keys.
    """
    expected_by_modality = {
        "text": ("text_pair", "text_target", "text_pair_target"),
        "image": ("do_convert_rgb", "do_resize"),
        # This group was previously labelled "audio"; these are video keys.
        "video": (
            "fps",
            "do_sample_frames",
            "input_data_format",
            "default_to_square",
        ),
        "audio": ("padding", "return_attention_mask"),
    }
    for modality, expected in expected_by_modality.items():
        missing = [k for k in expected if k not in keys]
        assert not missing, f"missing {modality} kwargs: {missing}"
36+
37+
38+
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs]
39+
class _ProcWithUnpack:
    """Processor stub whose ``__call__`` annotates ``**kwargs`` with ``Unpack``."""

    def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]):  # type: ignore
        """Accept anything and return ``None``; only the signature matters."""
42+
43+
44+
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
    """Keys collected via the ``Unpack``-annotated ``__call__`` cover all groups."""
    collected = get_processor_kwargs_from_processor(_ProcWithUnpack())
    _assert_has_all_expected(collected)
48+
49+
50+
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ----
51+
52+
53+
class _ProcWithoutUnpack:
    """Processor stub whose ``__call__`` leaves ``**kwargs`` unannotated."""

    def __call__(self, *args, **kwargs):
        """Accept anything and return ``None``; only the signature matters."""
56+
57+
58+
def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
    """Keys collected via the module-scan fallback cover all groups."""
    # The fallback scans the module that defines the processor class for
    # names ending in "ProcessingKwargs". Check that precondition directly:
    # `_FakeProcessorKwargs` does NOT match that suffix, so asserting its
    # presence would not prove the scan has anything to find. The imported
    # `ProcessingKwargs` base is what satisfies the scan in this module.
    mod = importlib.import_module(_ProcWithoutUnpack.__module__)
    assert any(name.endswith("ProcessingKwargs") for name in vars(mod))

    keys = get_processor_kwargs_from_processor(_ProcWithoutUnpack())
    _assert_has_all_expected(keys)

vllm/transformers_utils/processor.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import importlib
5+
import inspect
46
from functools import lru_cache
5-
from typing import TYPE_CHECKING, Any, cast
7+
from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints
68

79
from transformers import (
810
AutoFeatureExtractor,
@@ -55,6 +57,23 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
5557
return processor_cls
5658

5759

60+
@lru_cache
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
    """Collect the kwarg names declared by a ``*ProcessingKwargs``-style class.

    Args:
        kwargs_cls: Class whose type hints may declare ``text_kwargs``,
            ``images_kwargs``, ``videos_kwargs`` and/or ``audio_kwargs``
            entries, each annotated with another class carrying the
            per-modality kwarg names. ``None`` is tolerated.

    Returns:
        An empty set when ``kwargs_cls`` is ``None``. Otherwise the union of
        the field names of every modality group that is present, plus the
        four group names themselves.
    """
    if kwargs_cls is None:
        return set()

    group_names = ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs")
    hints = get_type_hints(kwargs_cls)

    # The group names are always part of the result, whether or not the
    # class declares them.
    collected: set[str] = set(group_names)
    for group in group_names:
        if group not in hints:
            continue
        # Iterating the hints mapping yields the field names of the group.
        collected.update(get_type_hints(hints[group]))
    return collected
75+
76+
5877
def _merge_mm_kwargs(
5978
model_config: "ModelConfig",
6079
processor_cls: type | tuple[type, ...],
@@ -71,7 +90,6 @@ def _merge_mm_kwargs(
7190
requires_kw_only=False,
7291
allow_var_kwargs=True,
7392
)
74-
7593
# NOTE: Pythonic dict is not hashable and will raise unhashable type
7694
# error when calling `cached_get_processor`, therefore we need to
7795
# wrap it to a hashable dict.
@@ -145,12 +163,80 @@ def get_processor(
145163
cached_get_processor = lru_cache(get_processor)
146164

147165

166+
@lru_cache
def get_processor_kwargs_from_processor(processor: _P) -> set[str]:
    """Return the call-time kwarg names accepted by ``processor``.

    Two strategies are tried, in order:

    1. If ``type(processor).__call__`` has a parameter named ``kwargs`` whose
       annotation carries a type argument (e.g. ``Unpack[XProcessingKwargs]``),
       the keys are collected from that type argument.
    2. Otherwise, every class in the module defining ``type(processor)`` whose
       name ends with ``"ProcessingKwargs"`` is scanned and the results are
       unioned.

    The result is memoized with ``lru_cache``, so ``processor`` must be
    hashable.

    Args:
        processor: The processor instance to inspect.

    Returns:
        The set of collected kwarg names. This is best-effort: any error
        raised during introspection yields an empty set.
    """
    try:
        # Read the raw annotation of ``**kwargs`` from the signature rather
        # than calling get_type_hints() on ``__call__``: get_type_hints()
        # resolves every annotation at runtime and raises if one refers to a
        # name that has not been imported or defined.
        call_kwargs = inspect.signature(type(processor).__call__).parameters.get(
            "kwargs"
        )
        annotation = call_kwargs.annotation if call_kwargs else None

        # Use the public ``Parameter.empty`` sentinel instead of the private
        # ``inspect._empty`` alias.
        type_args = (
            get_args(annotation)
            if annotation not in (None, inspect.Parameter.empty)
            else ()
        )
        if type_args:
            return _collect_dynamic_keys_from_processing_kwargs(type_args[0])

        # No usable annotation (missing, or without type arguments, e.g. a
        # string annotation): fall back to scanning the processor's module.
        mod = importlib.import_module(type(processor).__module__)
        processor_kwargs: set[str] = set()
        for name, obj in vars(mod).items():
            # Skip non-class attributes so that one unrelated object with a
            # matching name cannot make the whole scan fail.
            if name.endswith("ProcessingKwargs") and isinstance(obj, type):
                processor_kwargs |= _collect_dynamic_keys_from_processing_kwargs(obj)
        return processor_kwargs
    except Exception:
        # Deliberate best-effort: introspection failures mean "no keys".
        return set()
198+
199+
200+
def cached_get_processor_without_dynamic_kwargs(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Fetch a cached processor, ignoring kwargs that are call-time only.

    Kwargs whose names are reported by ``get_processor_kwargs_from_processor``
    are removed before the cached lookup, so they do not become part of the
    arguments passed to ``cached_get_processor``.

    Args:
        processor_name: Forwarded to ``cached_get_processor``.
        *args: Extra positional arguments, forwarded to ``cached_get_processor``.
        revision: Forwarded to ``cached_get_processor``.
        trust_remote_code: Forwarded to ``cached_get_processor``.
        processor_cls: Forwarded to ``cached_get_processor``.
        **kwargs: Processor kwargs; dynamic ones are filtered out.

    Returns:
        The processor returned by ``cached_get_processor`` for the remaining
        (static) kwargs.
    """
    # Step 1: get a temporary processor instance using no extra kwargs.
    # ``*args`` is forwarded here and below; previously it was accepted but
    # silently dropped.
    # NOTE(review): assumes ``cached_get_processor`` accepts extra positional
    # arguments, mirroring this signature -- confirm against its definition.
    default_processor = cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
    )

    # Step 2: use the temporary processor to collect the dynamic keys.
    dynamic_keys = get_processor_kwargs_from_processor(default_processor)

    # Step 3: keep only the kwargs that are not dynamic.
    static_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}

    # Step 4: get the final processor instance using the static kwargs only.
    return cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **static_kwargs,
    )
232+
233+
148234
def cached_processor_from_config(
149235
model_config: "ModelConfig",
150236
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
151237
**kwargs: Any,
152238
) -> _P:
153-
return cached_get_processor(
239+
return cached_get_processor_without_dynamic_kwargs(
154240
model_config.model,
155241
revision=model_config.revision,
156242
trust_remote_code=model_config.trust_remote_code,

0 commit comments

Comments
 (0)