
Commit bbe166c

Merge branch 'main' into release/3.3
2 parents f57790c + 3c14051

4 files changed: +24 -12 lines changed

swift/llm/infer/deploy.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import inspect
 import multiprocessing
 import time
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from dataclasses import asdict
 from http import HTTPStatus
 from threading import Thread

swift/llm/infer/infer.py

Lines changed: 0 additions & 2 deletions
@@ -1,9 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from contextlib import nullcontext
 from typing import Any, Dict, List, Union
 
 import numpy as np
-import torch.distributed as dist
 from datasets import Dataset as HfDataset
 
 from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset

swift/llm/infer/infer_engine/vllm_engine.py

Lines changed: 12 additions & 6 deletions
@@ -187,8 +187,7 @@ def _prepare_engine_kwargs(
     def _fix_vllm_bug(self) -> None:
         # fix vllm==0.4 bug (very slow)
         tokenizer = self.tokenizer
-        if version.parse(
-                vllm.__version__) >= version.parse('0.4') and not tokenizer.__class__.__name__.startswith('Cached'):
+        if self._version_ge('0.4') and not tokenizer.__class__.__name__.startswith('Cached'):
             _tokenizer_len = len(tokenizer)
             __old_len__ = tokenizer.__class__.__len__
 
@@ -224,6 +223,13 @@ def _add_stop_words(self, generation_config: SamplingParams, request_config: Req
         stop_words = (request_config.stop or []) + (self.generation_config.stop or []) + template_meta.stop_words
         generation_config.stop = self._get_stop_words(stop_words)
 
+    @staticmethod
+    def _version_ge(base_version: str):
+        vllm_version = vllm.__version__
+        if vllm_version is None or 'dev' in vllm_version:
+            return True
+        return version.parse(vllm_version) >= version.parse(base_version)
+
     def _add_request(self,
                      inputs: Dict[str, Any],
                      generation_config: SamplingParams,
@@ -241,18 +247,18 @@ def _add_request(self,
                 lora_name=adapter_name, lora_path=adapter_path, lora_int_id=len(self._adapters_pool) + 1)
             self._adapters_pool[adapter_name] = kwargs['lora_request']
         input_ids = inputs['input_ids']
-        if version.parse(vllm.__version__) >= version.parse('0.4.3'):
+        if self._version_ge('0.4.3'):
             llm_inputs = {'prompt_token_ids': input_ids}
             mm_data = {}
             for key in ['images', 'audios', 'videos']:
                 media_data = inputs.get(key) or []
                 if media_data:
-                    if version.parse(vllm.__version__) < version.parse('0.6'):
+                    if self._version_ge('0.6'):
+                        mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
+                    else:
                         assert len(media_data) == 1, (
                             f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0')
                         mm_data = {key.rstrip('s'): media_data[0]}
-                    else:
-                        mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
             if mm_data:
                 llm_inputs['multi_modal_data'] = mm_data
             if self.use_async_engine:
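The refactor above replaces the inline version.parse(vllm.__version__) comparisons with a single _version_ge helper that treats dev builds as new enough. A minimal standalone sketch of the same comparison logic (the module-level function name here is illustrative, not part of the diff); the point is that packaging parses a pre-release such as '0.6.0.dev0' as older than '0.6', so dev/nightly installs of vllm would otherwise fail the feature checks:

from packaging import version

import vllm


def version_ge(base_version: str) -> bool:
    # Dev/nightly builds are assumed to satisfy any minimum version:
    # packaging treats '0.6.0.dev0' as a pre-release, i.e. older than '0.6'.
    vllm_version = vllm.__version__
    if vllm_version is None or 'dev' in vllm_version:
        return True
    return version.parse(vllm_version) >= version.parse(base_version)


# Why the dev bypass is needed:
assert version.parse('0.6.0.dev0') < version.parse('0.6')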

tests/test_align/test_vllm_vlm.py

Lines changed: 11 additions & 3 deletions
@@ -31,8 +31,9 @@ def _infer_image(model, use_chat_template: bool = True, max_model_len=8192, syst
     return resp_list[0].choices[0].message.content
 
 
-def _infer_video(model, use_chat_template: bool = True, max_model_len=8192, system=None):
-    engine = VllmEngine(model, max_model_len=max_model_len, limit_mm_per_prompt={'image': 16, 'video': 2})
+def _infer_video(model, use_chat_template: bool = True, max_model_len=8192, system=None, limit_mm_per_prompt=None):
+    limit_mm_per_prompt = limit_mm_per_prompt or {'image': 16, 'video': 2}
+    engine = VllmEngine(model, max_model_len=max_model_len, limit_mm_per_prompt=limit_mm_per_prompt)
     if not use_chat_template:
         engine.default_template.use_chat_template = False
     videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
@@ -116,6 +117,12 @@ def test_qwen2_5_vl_video():
            'on a white blanket. The baby is looking at the book and is smiling. The baby')
 
 
+def test_qwen2_5_omni():
+    limit_mm_per_prompt = {'image': 1, 'video': 1, 'audio': 1}
+    response = _infer_video('Qwen/Qwen2.5-Omni-7B', limit_mm_per_prompt=limit_mm_per_prompt)
+    assert response
+
+
 if __name__ == '__main__':
     from swift.llm import VllmEngine, InferRequest, RequestConfig
     # test_qwen2_vl()
@@ -125,5 +132,6 @@ def test_qwen2_5_vl_video():
     # test_qwen2_audio()
     # test_minicpmv_2_5()
     # test_minicpmv_2_6()
-    test_minicpmo_2_6_video()
+    # test_minicpmo_2_6_video()
     # test_qwen2_5_vl_video()
+    test_qwen2_5_omni()
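test_qwen2_5_omni exercises the new limit_mm_per_prompt passthrough so the omni model can accept one image, one video, and one audio item per prompt. A rough usage sketch outside the test harness, reusing the imports and the engine/response pattern from this file; the InferRequest fields, the engine.infer call, and the RequestConfig values are assumptions not shown in the diff:

from swift.llm import InferRequest, RequestConfig, VllmEngine

# Cap each modality at one item per prompt, as in test_qwen2_5_omni.
engine = VllmEngine(
    'Qwen/Qwen2.5-Omni-7B',
    max_model_len=8192,
    limit_mm_per_prompt={'image': 1, 'video': 1, 'audio': 1})

# Hypothetical request mirroring the video URL used by _infer_video.
infer_request = InferRequest(
    messages=[{'role': 'user', 'content': '<video>Describe this video.'}],
    videos=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4'])
resp_list = engine.infer([infer_request], RequestConfig(max_tokens=128, temperature=0))
print(resp_list[0].choices[0].message.content)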
