44from dataclasses import dataclass
55from typing import TYPE_CHECKING , Any , NamedTuple , Optional , Union
66
7+ import torch
78from transformers import BatchFeature , PretrainedConfig , ProcessorMixin
89from typing_extensions import TypeVar
910
11+ from vllm .jsontree import JSONTree , json_map_leaves
12+ from vllm .logger import init_logger
1013from vllm .transformers_utils .processor import cached_processor_from_config
1114from vllm .transformers_utils .tokenizer import AnyTokenizer
1215from vllm .utils import resolve_mm_processor_kwargs
# Type variables for the HF config / processor classes handled by this
# module. `default=` (PEP 696, via typing_extensions.TypeVar) lets
# un-parameterized usages fall back to the base transformers types.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)

# Module-level logger, following the project convention of one logger
# per module named after it.
logger = init_logger(__name__)
2429
2530@dataclass (frozen = True )
2631class InputContext :
@@ -134,7 +139,7 @@ def call_hf_processor(
134139 hf_processor : ProcessorMixin ,
135140 data : Mapping [str , object ],
136141 kwargs : Mapping [str , object ] = {},
137- ) -> BatchFeature :
142+ ) -> Union [ BatchFeature , JSONTree ] :
138143 """
139144 Call `hf_processor` on the prompt `data`
140145 (text, image, audio...) with configurable options `kwargs`.
@@ -154,8 +159,25 @@ def call_hf_processor(
154159 allow_var_kwargs = True ,
155160 )
156161
def maybe_cast_dtype(x):
    """Cast floating-point tensors to the model dtype; pass everything
    else through unchanged.

    This mimics the dtype-casting behavior of
    ``transformers.BatchFeature.to(dtype=...)``, which only converts
    floating-point tensors (integer/bool tensors and non-tensor leaves
    are returned as-is).
    """
    # Only floating-point tensors are converted; `self.model_config.dtype`
    # is read from the enclosing method's scope.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
        return x.to(dtype=self.model_config.dtype)
    return x
167+
157168 try :
158- return hf_processor (** data , ** merged_kwargs , return_tensors = "pt" )
169+ output = hf_processor (** data , ** merged_kwargs , return_tensors = "pt" )
170+ # this emulates output.to(dtype=self.model_config.dtype)
171+ cast_output = json_map_leaves (maybe_cast_dtype , output )
172+ if isinstance (output , BatchFeature ):
173+ return BatchFeature (cast_output )
174+
175+ logger .warning_once (
176+ f"{ type (hf_processor ).__name__ } did not return `BatchFeature`. "
177+ "Make sure to match the behaviour of `ProcessorMixin` when "
178+ "implementing custom processors." )
179+ return cast_output
180+
159181 except Exception as exc :
160182 msg = (f"Failed to apply { type (hf_processor ).__name__ } "
161183 f"on data={ data } with kwargs={ merged_kwargs } " )
0 commit comments