Skip to content

Commit fe8d7b6

Browse files
[Model] Interface to enable batch-level DP support (vllm-project#23733)
Signed-off-by: DarkLight1337 <[email protected]> Signed-off-by: Cyrus Leung <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 16dc405 commit fe8d7b6

File tree

8 files changed

+38
-4
lines changed

8 files changed

+38
-4
lines changed

docs/configuration/optimization.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,11 @@ llm = LLM(
168168
Batch-level DP is not to be confused with API request-level DP
169169
(which is instead controlled by `data_parallel_size`).
170170

171-
The availability of batch-level DP is based on model implementation.
172-
Currently, the following models support `mm_encoder_tp_mode="data"`:
171+
Batch-level DP needs to be implemented on a per-model basis,
172+
and enabled by setting `supports_encoder_tp_data = True` in the model class.
173+
Regardless of model support, you also need to set `mm_encoder_tp_mode="data"` in the engine arguments to use this feature.
174+
175+
Known supported models:
173176

174177
- Llama4 (<gh-pr:18368>)
175178
- MiniCPM-V-4 (<gh-pr:23327>)

vllm/config/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,13 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
872872

873873
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
874874
if self._model_info.supports_multimodal:
875+
if (self.mm_encoder_tp_mode == "data" and
876+
not self._model_info.supports_multimodal_encoder_tp_data):
877+
logger.warning_once(
878+
"This model does not support `--mm-encoder-tp-mode data`. "
879+
"Falling back to `--mm-encoder-tp-mode weights`.")
880+
self.mm_encoder_tp_mode = "weights"
881+
875882
return MultiModalConfig(
876883
limit_per_prompt=self.limit_mm_per_prompt,
877884
media_io_kwargs=self.media_io_kwargs,

vllm/model_executor/models/interfaces.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
5252
MRO of your model class.
5353
"""
5454

55+
supports_encoder_tp_data: ClassVar[bool] = False
56+
"""
57+
A flag that indicates whether this model supports
58+
`multimodal_config.mm_encoder_tp_mode="data"`.
59+
"""
60+
5561
@classmethod
5662
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
5763
"""
@@ -137,6 +143,11 @@ def supports_multimodal(
137143
return getattr(model, "supports_multimodal", False)
138144

139145

146+
def supports_multimodal_encoder_tp_data(
147+
model: Union[type[object], object]) -> bool:
148+
return getattr(model, "supports_encoder_tp_data", False)
149+
150+
140151
@runtime_checkable
141152
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
142153
"""The interface required for all multi-modal models."""

vllm/model_executor/models/minicpmv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
15211521
],
15221522
}
15231523

1524+
supports_encoder_tp_data = True
1525+
15241526
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
15251527
super().__init__(vllm_config=vllm_config, prefix=prefix)
15261528
assert self.version == (4, 0)

vllm/model_executor/models/mllama4.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
716716
"gate_up_proj": ["gate_proj", "up_proj"],
717717
}
718718

719+
supports_encoder_tp_data = True
720+
719721
@classmethod
720722
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
721723
if modality.startswith("image"):

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
868868
"model.": "language_model.model.",
869869
})
870870

871+
supports_encoder_tp_data = True
872+
871873
@classmethod
872874
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
873875
if modality.startswith("image"):

vllm/model_executor/models/registry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727

2828
from .interfaces import (has_inner_state, has_noops, is_attention_free,
2929
is_hybrid, supports_cross_encoding,
30-
supports_multimodal, supports_multimodal_raw_input,
31-
supports_pp, supports_transcription, supports_v0_only)
30+
supports_multimodal,
31+
supports_multimodal_encoder_tp_data,
32+
supports_multimodal_raw_input, supports_pp,
33+
supports_transcription, supports_v0_only)
3234
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
3335
is_text_generation_model)
3436

@@ -324,6 +326,7 @@ class _ModelInfo:
324326
supports_cross_encoding: bool
325327
supports_multimodal: bool
326328
supports_multimodal_raw_input: bool
329+
supports_multimodal_encoder_tp_data: bool
327330
supports_pp: bool
328331
has_inner_state: bool
329332
is_attention_free: bool
@@ -343,6 +346,8 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
343346
supports_cross_encoding=supports_cross_encoding(model),
344347
supports_multimodal=supports_multimodal(model),
345348
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
349+
supports_multimodal_encoder_tp_data=
350+
supports_multimodal_encoder_tp_data(model),
346351
supports_pp=supports_pp(model),
347352
has_inner_state=has_inner_state(model),
348353
is_attention_free=is_attention_free(model),

vllm/model_executor/models/step3_vl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
867867
"lm_head.": "language_model.lm_head.",
868868
})
869869

870+
supports_encoder_tp_data = True
871+
870872
@classmethod
871873
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
872874
if modality.startswith("image"):

0 commit comments

Comments
 (0)