Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit df1a211

Browse files
authored
[Model] Fix Phi-3.5-vision-instruct 'num_crops' issue (vllm-project#7710)
1 parent 7937009 commit df1a211

File tree

6 files changed

+37
-13
lines changed

6 files changed

+37
-13
lines changed

docs/source/models/supported_models.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ Multimodal Language Models
225225
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
226226
-
227227
* - :code:`Phi3VForCausalLM`
228-
- Phi-3-Vision
228+
- Phi-3-Vision, Phi-3.5-Vision
229229
- Image
230-
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
230+
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
231231
-
232232
* - :code:`MiniCPMV`
233233
- MiniCPM-V

tests/models/test_phi3v.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
2222
})
2323

24-
models = ["microsoft/Phi-3-vision-128k-instruct"]
24+
models = ["microsoft/Phi-3.5-vision-instruct"]
2525

2626

2727
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,

vllm/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from vllm.model_executor.models import ModelRegistry
1414
from vllm.platforms import current_platform
1515
from vllm.tracing import is_otel_available, otel_import_error_traceback
16-
from vllm.transformers_utils.config import get_config, get_hf_text_config
16+
from vllm.transformers_utils.config import (get_config,
17+
get_hf_image_processor_config,
18+
get_hf_text_config)
1719
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
1820
cuda_device_count_stateless, get_cpu_memory, is_cpu,
1921
is_hip, is_neuron, is_openvino, is_xpu,
@@ -167,6 +169,8 @@ def __init__(
167169
self.hf_config = get_config(self.model, trust_remote_code, revision,
168170
code_revision, rope_scaling, rope_theta)
169171
self.hf_text_config = get_hf_text_config(self.hf_config)
172+
self.hf_image_processor_config = get_hf_image_processor_config(
173+
self.model, revision)
170174
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
171175

172176
# Choose a default enforce_eager value if the user did not specify

vllm/inputs/registry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from array import array
33
from collections import UserDict
44
from dataclasses import dataclass
5-
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
6-
Tuple, Type)
5+
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
6+
Protocol, Tuple, Type)
77

88
from torch import nn
99
from transformers import PretrainedConfig
@@ -55,6 +55,13 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
5555

5656
return hf_config
5757

58+
def get_hf_image_processor_config(self) -> Dict[str, Any]:
59+
"""
60+
Get the HuggingFace image processor configuration of the model.
61+
"""
62+
63+
return self.model_config.hf_image_processor_config
64+
5865

5966
N = TypeVar("N", bound=Type[nn.Module])
6067

vllm/model_executor/models/phi3v.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# limitations under the License.
1616
import re
1717
from functools import lru_cache
18-
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
19-
TypedDict, Union)
18+
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
19+
Tuple, TypedDict, Union)
2020

2121
import numpy as np
2222
import torch
@@ -324,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
324324

325325
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
326326
def get_phi3v_image_feature_size(
327-
hf_config: PretrainedConfig,
327+
hf_config: Dict[str, Any],
328328
*,
329329
input_height: int,
330330
input_width: int,
331331
) -> int:
332-
num_crops = getattr(hf_config, "num_crops", 16)
332+
num_crops = hf_config.get("num_crops", 16)
333333
new_width, new_height = _calc_hd_transform_size(width=input_width,
334334
height=input_height,
335335
hd_num=num_crops)
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
341341
def get_max_phi3v_image_tokens(ctx: InputContext):
342342

343343
return get_phi3v_image_feature_size(
344-
ctx.get_hf_config(),
344+
ctx.get_hf_image_processor_config(),
345345
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
346346
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
347347
)
@@ -395,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
395395
return llm_inputs
396396

397397
model_config = ctx.model_config
398-
hf_config = ctx.get_hf_config()
398+
hf_config = ctx.get_hf_image_processor_config()
399399

400400
image_data = multi_modal_data["image"]
401401
if isinstance(image_data, Image.Image):

vllm/transformers_utils/config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import contextlib
22
from pathlib import Path
3-
from typing import Dict, Optional, Type, Union
3+
from typing import Any, Dict, Optional, Type, Union
44

55
from transformers import GenerationConfig, PretrainedConfig
6+
from transformers.models.auto.image_processing_auto import (
7+
get_image_processor_config)
68
from transformers.models.auto.modeling_auto import (
79
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
810

@@ -98,6 +100,17 @@ def get_config(
98100
return config
99101

100102

103+
def get_hf_image_processor_config(
104+
model: Union[str, Path],
105+
revision: Optional[str] = None,
106+
**kwargs,
107+
) -> Dict[str, Any]:
108+
# Separate model folder from file path for GGUF models
109+
if Path(model).is_file() and Path(model).suffix == ".gguf":
110+
model = Path(model).parent
111+
return get_image_processor_config(model, revision=revision, **kwargs)
112+
113+
101114
def get_hf_text_config(config: PretrainedConfig):
102115
"""Get the "sub" config relevant to llm for multi modal models.
103116
No op for pure text models.

0 commit comments

Comments
 (0)