@@ -187,8 +187,7 @@ def _prepare_engine_kwargs(
187187 def _fix_vllm_bug (self ) -> None :
188188 # fix vllm==0.4 bug (very slow)
189189 tokenizer = self .tokenizer
190- if version .parse (
191- vllm .__version__ ) >= version .parse ('0.4' ) and not tokenizer .__class__ .__name__ .startswith ('Cached' ):
190+ if self ._version_ge ('0.4' ) and not tokenizer .__class__ .__name__ .startswith ('Cached' ):
192191 _tokenizer_len = len (tokenizer )
193192 __old_len__ = tokenizer .__class__ .__len__
194193
@@ -224,6 +223,13 @@ def _add_stop_words(self, generation_config: SamplingParams, request_config: Req
224223 stop_words = (request_config .stop or []) + (self .generation_config .stop or []) + template_meta .stop_words
225224 generation_config .stop = self ._get_stop_words (stop_words )
226225
226+ @staticmethod
227+ def _version_ge (base_version : str ):
228+ vllm_version = vllm .__version__
229+ if vllm_version is None or 'dev' in vllm_version :
230+ return True
231+ return version .parse (vllm_version ) >= version .parse (base_version )
232+
227233 def _add_request (self ,
228234 inputs : Dict [str , Any ],
229235 generation_config : SamplingParams ,
@@ -241,18 +247,18 @@ def _add_request(self,
241247 lora_name = adapter_name , lora_path = adapter_path , lora_int_id = len (self ._adapters_pool ) + 1 )
242248 self ._adapters_pool [adapter_name ] = kwargs ['lora_request' ]
243249 input_ids = inputs ['input_ids' ]
244- if version . parse ( vllm . __version__ ) >= version . parse ('0.4.3' ):
250+ if self . _version_ge ('0.4.3' ):
245251 llm_inputs = {'prompt_token_ids' : input_ids }
246252 mm_data = {}
247253 for key in ['images' , 'audios' , 'videos' ]:
248254 media_data = inputs .get (key ) or []
249255 if media_data :
250- if version .parse (vllm .__version__ ) < version .parse ('0.6' ):
256+ if self ._version_ge ('0.6' ):
257+ mm_data = {key .rstrip ('s' ): media_data [0 ] if len (media_data ) == 1 else media_data }
258+ else :
251259 assert len (media_data ) == 1 , (
252260 f'The current version of vllm only supports single { key } . Please upgrade to vllm >= 0.6.0' )
253261 mm_data = {key .rstrip ('s' ): media_data [0 ]}
254- else :
255- mm_data = {key .rstrip ('s' ): media_data [0 ] if len (media_data ) == 1 else media_data }
256262 if mm_data :
257263 llm_inputs ['multi_modal_data' ] = mm_data
258264 if self .use_async_engine :
0 commit comments