Skip to content

Commit 9222a55

Browse files
committed
add UT
1 parent 899f428 commit 9222a55

File tree

2 files changed

+60
-9
lines changed

lmdeploy/vl/model/internvl3_hf.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,30 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
147147
messages.append(dict(role='forward', content=outputs))
148148
return messages
149149

150+
def proc_internvl_hf_messages(self, content: List[Dict]):
    """Flatten a 'user' content list into a prompt string for InternVL HF models.

    Text items contribute their text verbatim; every image item contributes
    the image token followed by a newline. Any other item type is rejected.

    Args:
        content (List[Dict]): content entries, each with a 'type' key of
            'text', 'image' or 'image_url'.

    Returns:
        str: the concatenated prompt fragment.

    Raises:
        ValueError: if an entry has an unsupported 'type'.
    """
    pieces = []
    for entry in content:
        kind = entry['type']
        if kind == 'text':
            pieces.append(entry['text'])
        elif kind in ('image', 'image_url'):
            # InternVL HF templates place a newline right after each image token.
            pieces.append(f'{self.image_token}\n')
        else:
            raise ValueError(f'Unsupported message type: {entry["type"]}')
    return ''.join(pieces)
161+
162+
def proc_interns1_messages(self, content: List[Dict]):
    """Flatten a 'user' content list into a prompt string for InternS1 models.

    Text items contribute their text and image items contribute the image
    token; the collected pieces are joined with newlines.

    Args:
        content (List[Dict]): content entries, each with a 'type' key of
            'text', 'image' or 'image_url'.

    Returns:
        str: the newline-joined prompt fragment.

    Raises:
        ValueError: if an entry has an unsupported 'type'.
    """

    def render(entry):
        kind = entry['type']
        if kind == 'text':
            return entry['text']
        if kind in ('image', 'image_url'):
            return f'{self.image_token}'
        raise ValueError(f'Unsupported message type: {entry["type"]}')

    return '\n'.join(render(entry) for entry in content)
173+
150174
def proc_messages(
151175
self,
152176
messages,
@@ -162,15 +186,9 @@ def proc_messages(
162186
continue
163187
role, content = message['role'], message['content']
164188
if role == 'user' and isinstance(content, List):
165-
_content = []
166-
for item in content:
167-
if item['type'] == 'text':
168-
_content.append(item['text'])
169-
elif item['type'] in ['image', 'image_url']:
170-
_content.append(self.image_token)
171-
else:
172-
raise ValueError(f'Unsupported message type: {item["type"]}')
173-
message = dict(role=role, content='\n'.join(_content))
189+
content = (self.proc_internvl_hf_messages(content)
190+
if self.arch == 'InternVLForConditionalGeneration' else self.proc_interns1_messages(content))
191+
message = dict(role=role, content=content)
174192
prompt_messages.append(message)
175193
else:
176194
prompt_messages.append(message)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
from lmdeploy.model import HFChatTemplate
4+
from lmdeploy.vl.model.internvl3_hf import InternVL3VisionModel
5+
6+
TEST_MODELS = ['OpenGVLab/InternVL3_5-8B-HF', 'internlm/Intern-S1-mini']
7+
8+
9+
@pytest.fixture(scope='module')
def mock_messages():
    """A single-turn user message mixing text with two identical image URLs."""
    image_url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    content = [
        dict(type='text', text='Describe the following images in detail'),
        dict(type='image', url=dict(url=image_url)),
        dict(type='image', url=dict(url=image_url)),
        dict(type='text', text='How many cats are there in total?'),
    ]
    return [dict(role='user', content=content)]
20+
21+
22+
def test_proc_messages(mock_messages):
    """proc_messages must reproduce the HF processor's own chat-template prompt.

    For each model, render the reference prompt with the processor's
    ``apply_chat_template`` and assert that ``proc_messages`` produces the
    identical string.
    """
    for model_path in TEST_MODELS:
        vision_model = InternVL3VisionModel(model_path=model_path, with_llm=False)
        vision_model.build_preprocessor()
        # Reference prompt rendered by the HF processor itself.
        # NOTE(review): return_dict=True is documented to take effect only when
        # tokenize=True — confirm this returns a plain string, not a dict.
        reference = vision_model.processor.apply_chat_template(mock_messages,
                                                               add_generation_prompt=True,
                                                               tokenize=False,
                                                               return_dict=True)
        chat_template = HFChatTemplate(model_path=model_path)
        # Fix: the original called proc_messages twice with identical arguments
        # and discarded the first result; a single call is sufficient.
        prompt, _ = vision_model.proc_messages(mock_messages, chat_template, sequence_start=True)
        assert prompt == reference

0 commit comments

Comments
 (0)