Commit 62a4403

support interleave text and image in messages

1 parent 1a859f4

3 files changed: +95 additions, -18 deletions

lmdeploy/messages.py

Lines changed: 27 additions & 0 deletions

@@ -473,6 +473,33 @@ class Response:
     index: int = 0
     routed_experts: Any = None
 
+    def __str__(self):
+        fields = []
+
+        fields.append('text=')
+        fields.append(self.text if self.text is not None else 'None')
+        fields.append(f'input_token_len={self.input_token_len}')
+        fields.append(f'generate_token_len={self.generate_token_len}')
+        fields.append(f'finish_reason="{self.finish_reason}"')
+        fields.append(f'token_ids={self.token_ids}')
+        fields.append(f'logprobs={self.logprobs}')
+
+        # Helper function to format tensor information
+        def _format_tensor(name: str, tensor: Optional[torch.Tensor]) -> List[str]:
+            if tensor is None:
+                return [f'{name}=None']
+            return [f'{name}.shape={tensor.shape}', f'{name}={tensor}']
+
+        # Format tensor fields
+        fields.extend(_format_tensor('logits', self.logits))
+        fields.extend(_format_tensor('last_hidden_state', self.last_hidden_state))
+
+        if self.routed_experts is None:
+            fields.append('routed_experts=None')
+        else:
+            fields.append(f'routed_experts.shape={self.routed_experts.shape}')
+        return '\n'.join(fields)
+
     def __repr__(self):
         logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}'
         hidden_state = (
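
For illustration only (not part of the commit), a minimal sketch of what the new `__str__` prints for a response carrying no tensors. The field values below are made up, and it assumes the remaining `Response` fields default to `None` as in the dataclass above:

    from lmdeploy.messages import Response

    # Hypothetical field values, for illustration only.
    resp = Response(text='A cat.',
                    input_token_len=12,
                    generate_token_len=3,
                    finish_reason='stop',
                    token_ids=[32, 1024, 2])
    print(resp)  # print() goes through __str__, not __repr__
    # text=
    # A cat.
    # input_token_len=12
    # generate_token_len=3
    # finish_reason="stop"
    # token_ids=[32, 1024, 2]
    # logprobs=None
    # logits=None
    # last_hidden_state=None
    # routed_experts=None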

lmdeploy/vl/model/internvl3_hf.py

Lines changed: 35 additions & 18 deletions

@@ -44,7 +44,7 @@ def __init__(self,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         super().__init__(model_path, with_llm, max_memory, hf_config, backend)
-        self.arch = hf_config.architectures[0]
+        self.arch = self.hf_config.architectures[0]
 
     def build_preprocessor(self):
         self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
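
The one-line fix above reads the architecture from `self.hf_config` instead of the raw `hf_config` argument; presumably the base initializer populates `self.hf_config` even when the caller passes `hf_config=None`, making the attribute the safer source of truth.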
@@ -146,8 +146,32 @@ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         messages.append(dict(role='forward', content=outputs))
         return messages
 
-    @staticmethod
+    def proc_internvl_hf_messages(self, content: List[Dict], IMAGE_TOKEN: str):
+        """Process the content list of role 'user' for InternVL HF models."""
+        res = []
+        for item in content:
+            if item['type'] == 'text':
+                res.append(item['text'])
+            elif item['type'] in ['image', 'image_url']:
+                res.append(f'{IMAGE_TOKEN}\n')
+            else:
+                raise ValueError(f'Unsupported message type: {item["type"]}')
+        return ''.join(res)
+
+    def proc_interns1_messages(self, content: List[Dict], IMAGE_TOKEN: str):
+        """Process the content list of role 'user' for InternS1 models."""
+        res = []
+        for item in content:
+            if item['type'] == 'text':
+                res.append(item['text'])
+            elif item['type'] in ['image', 'image_url']:
+                res.append(IMAGE_TOKEN)
+            else:
+                raise ValueError(f'Unsupported message type: {item["type"]}')
+        return '\n'.join(res)
+
     def proc_messages(
+        self,
         messages,
         chat_template,
         sequence_start,
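
A quick sketch of what the two helpers produce (not part of the commit; `model` stands for any instance exposing the methods above): the InternVL HF variant appends a newline after each image token and joins with the empty string, while the InternS1 variant joins every item with newlines.

    content = [
        dict(type='text', text='Compare these:'),
        dict(type='image', url=dict(url='a.jpg')),
        dict(type='text', text='What changed?'),
    ]
    model.proc_internvl_hf_messages(content, '<IMAGE_TOKEN>')
    # -> 'Compare these:<IMAGE_TOKEN>\nWhat changed?'
    model.proc_interns1_messages(content, '<IMAGE_TOKEN>')
    # -> 'Compare these:\n<IMAGE_TOKEN>\nWhat changed?'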
@@ -158,24 +182,17 @@ def proc_messages(
         prompt_messages = []
         IMAGE_TOKEN = '<IMAGE_TOKEN>'
         for message in messages:
-            if isinstance(message['content'], str):
-                prompt_messages.append(message)
+            if message['role'] in ['preprocess', 'forward']:
                 continue
-            elif message['role'] in ['preprocess', 'forward']:
-                continue
-            n_images = len([1 for x in message['content'] if x['type'] == 'image'])
-            content = [x.get('text', '') for x in message['content'] if x['type'] == 'text']
-            prompt = content[0]
-            if IMAGE_TOKEN in prompt and f'<img>{IMAGE_TOKEN}' not in prompt:
-                prompt = prompt.replace(f'{IMAGE_TOKEN}', f'<img>{IMAGE_TOKEN}</img>')
-                prompt = prompt.replace('</img><img>', '')
-                prompt = prompt.replace('<img><img>', '<img>')
-                prompt = prompt.replace('</img></img>', '</img>')
-            elif IMAGE_TOKEN not in prompt:
-                prompt = f'<img>{IMAGE_TOKEN * n_images}</img>\n' + prompt
+            role, content = message['role'], message['content']
+            if role == 'user' and isinstance(content, List):
+                content = (self.proc_internvl_hf_messages(content, IMAGE_TOKEN) if self.arch
+                           == 'InternVLForConditionalGeneration' else self.proc_interns1_messages(content, IMAGE_TOKEN))
+                message = dict(role=role, content=content)
+                prompt_messages.append(message)
             else:
-                pass
-            prompt_messages.append(dict(role='user', content=prompt))
+                prompt_messages.append(message)
+
         prompt = chat_template.messages2prompt(prompt_messages,
                                                sequence_start,
                                                tools=tools,
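
The net effect is that an interleaved user message is flattened into one string in document order before templating, rather than having all image tokens hoisted to the front as the old code did. A hypothetical walk-through (the message shape mirrors the test below):

    messages = [dict(role='user',
                     content=[dict(type='text', text='First:'),
                              dict(type='image', url=dict(url='a.jpg')),
                              dict(type='text', text='Second:'),
                              dict(type='image', url=dict(url='b.jpg'))])]
    # With arch == 'InternVLForConditionalGeneration', proc_messages rewrites
    # the message to:
    #   dict(role='user', content='First:<IMAGE_TOKEN>\nSecond:<IMAGE_TOKEN>\n')
    # and then hands it to chat_template.messages2prompt().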
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import pytest
+
+from lmdeploy.model import HFChatTemplate
+from lmdeploy.vl.model.internvl3_hf import InternVL3VisionModel
+
+TEST_MODELS = ['OpenGVLab/InternVL3_5-8B-HF', 'internlm/Intern-S1-mini']
+
+
+@pytest.fixture(scope='module')
+def mock_messages():
+    return [
+        dict(role='user',
+             content=[
+                 dict(type='text', text='Describe the following images in detail'),
+                 dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),
+                 dict(type='image', url=dict(url='http://images.cocodataset.org/val2017/000000039769.jpg')),
+                 dict(type='text', text='How many cats are there in total?')
+             ]),
+    ]
+
+
+def test_proc_messages(mock_messages):
+    for model_path in TEST_MODELS:
+        vision_model = InternVL3VisionModel(model_path=model_path, with_llm=False)
+        vision_model.build_preprocessor()
+        reference = vision_model.processor.apply_chat_template(mock_messages,
+                                                               add_generation_prompt=True,
+                                                               tokenize=False,
+                                                               return_dict=True)
+        chat_template = HFChatTemplate(model_path=model_path)
+        prompt, _ = vision_model.proc_messages(mock_messages, chat_template, sequence_start=True)
+        assert prompt.replace('<IMAGE_TOKEN>', '<IMG_CONTEXT>') == reference
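
The trailing `replace` in the assertion maps lmdeploy's generic `<IMAGE_TOKEN>` placeholder to `<IMG_CONTEXT>`, presumably the image-context token the HF processor emits, so the locally built prompt can be compared against the processor's own `apply_chat_template` output.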
