
Commit 07e16a9

Fix mplug owl2, molmo (#2724)

1 parent 1bcb9bb commit 07e16a9

11 files changed, +94 -187 lines

README.md (1 addition, 1 deletion)

```diff
@@ -49,7 +49,7 @@
 You can contact us and communicate with us by adding our group:
 
 
-[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
+[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
 :-------------------------:|:-------------------------:
 <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
 
```

swift/llm/infer/infer_engine/pt_engine.py (3 additions, 3 deletions)

```diff
@@ -168,10 +168,10 @@ def _infer_stream(self,
         if generation_config.output_logits:
             generate_kwargs['logits_processor'] = LogitsProcessorList([LogitsStreamer()])
 
-        def _model_generate(*args, **kwargs):
+        def _model_generate(**kwargs):
             if is_torch_npu_available():
                 torch.npu.set_device(self.model.device)
-            self.model.generate(*args, **kwargs)
+            template.generate(self.model, **kwargs)
 
         generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
         thread = Thread(target=_model_generate, kwargs=generate_kwargs)
@@ -269,7 +269,7 @@ def _infer_full(self,
         num_prompt_tokens = self._get_num_tokens(inputs)
 
         generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
-        output = dict(self.model.generate(**generate_kwargs))
+        output = dict(template.generate(self.model, **generate_kwargs))
         output.pop('past_key_values', None)
         batched_generate_ids = output['sequences']
         batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
```
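Both generation call sites now route through the template instead of calling `self.model.generate` directly; the default hook added to `swift/llm/template/base.py` below simply forwards to the model, so only templates that override it (here, the Molmo template) change behavior. A minimal sketch of the dispatch pattern, with illustrative class and function names rather than the actual PtEngine internals:

```python
from threading import Thread


class Template:
    # Default hook (mirrors the method added to swift/llm/template/base.py):
    # plain models keep their usual model.generate path.
    def generate(self, model, *args, **kwargs):
        return model.generate(*args, **kwargs)


class MolmoLikeTemplate(Template):
    # Models whose checkpoints expose a custom entry point (Molmo uses
    # generate_from_batch) can repack the kwargs before generation.
    def generate(self, model, **kwargs):
        generation_config = kwargs.pop('generation_config', None)
        batch = {k: kwargs.pop(k, None) for k in ('input_ids', 'images')}
        return model.generate_from_batch(batch, generation_config, **kwargs)


def infer_stream(model, template, **generate_kwargs):
    # As in _infer_stream: the worker thread calls template.generate,
    # so the template decides how the model is invoked.
    thread = Thread(target=lambda: template.generate(model, **generate_kwargs))
    thread.start()
    return thread
```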

swift/llm/model/model/mllm.py (1 addition, 25 deletions)

```diff
@@ -104,20 +104,6 @@ def to_dict(self, *args, **kwargs):
     if model is not None:
         model.config._to_dict = model.config.to_dict
         model.config.to_dict = MethodType(to_dict, model.config)
-        from transformers import GenerationMixin
-        model.generate = MethodType(GenerationMixin.generate, model)
-
-    if model and hasattr(model, '_old_forward'):  # device_map
-        device = model.lm_head.weight.device
-        forward_origin = model._old_forward
-
-        def _forward(*args, **kwargs):
-            if kwargs.get('append_last_valid_logits') is not None:
-                kwargs['append_last_valid_logits'] = kwargs['append_last_valid_logits'].to(device)
-            return forward_origin(*args, **kwargs)
-
-        model._old_forward = _forward
-        model.forward_origin = forward_origin
 
     return model, processor
 
@@ -148,18 +134,8 @@ def get_model_tokenizer_molmo(model_dir: str,
     model_cls = get_class_from_dynamic_module('modeling_molmo.MolmoForCausalLM', model_dir)
     model_cls._no_split_modules = ['MolmoSequentialBlock']
     model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
-    if model:
-        device = next(model.model.transformer.ff_out.parameters()).device
-        forward_origin = model.model.forward
-
-        def _forward(*args, **kwargs):
-            if kwargs.get('append_last_valid_logits') is not None:
-                kwargs['append_last_valid_logits'] = kwargs['append_last_valid_logits'].to(device)
-            return forward_origin(*args, **kwargs)
-
-        model.model.forward = _forward
-        model.model.forward_origin = forward_origin
 
+    patch_output_clone(model.model.transformer.wte)
     return model, processor
```
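The removed blocks monkey-patched `forward` only to move `append_last_valid_logits` onto the right device; with meta-input preparation dropped from the Molmo template (see the template changes below), that tensor no longer exists and the patches go with it. The one addition calls `patch_output_clone` on the word-embedding module `wte`; it is an ms-swift patcher helper, and judging by its name and this call site (an assumption, its body is not shown in this diff), it wraps the module so its forward output is cloned, roughly:

```python
from types import MethodType

import torch


def patch_output_clone(module: torch.nn.Module) -> None:
    # Hypothetical sketch of the helper's effect: clone the forward output so
    # later in-place writes (e.g. scattering image features into the embedded
    # sequence) do not modify a tensor that autograd still needs intact.
    forward_origin = module.forward

    def _forward(self, *args, **kwargs):
        return forward_origin(*args, **kwargs).clone()

    module.forward = MethodType(_forward, module)
```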

swift/llm/model/model/qwen.py (3 additions, 1 deletion)

```diff
@@ -23,8 +23,10 @@ def get_model_tokenizer_qwen(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
+                             model_config=None,
                              **kwargs):
-    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+    if model_config is None:
+        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     if model_config.torch_dtype is not None:
         k_true = dtype_mapping[model_config.torch_dtype]
         for k in dtype_mapping.values():
```
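The new optional `model_config` parameter lets a caller that has already loaded (and possibly modified) the config pass it in, so the function only reads `config.json` from disk when nothing was supplied. The same guard in isolation, with a placeholder function name:

```python
from typing import Optional

from transformers import AutoConfig, PretrainedConfig


def load_config_once(model_dir: str,
                     model_config: Optional[PretrainedConfig] = None) -> PretrainedConfig:
    # Callers that already hold a (possibly tweaked) config skip the disk read.
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    return model_config
```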

swift/llm/model/model_arch.py (1 addition, 1 deletion)

```diff
@@ -488,7 +488,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> None:
         MLLMModelArch.molmo,
         language_model='model.transformer',
         vision_tower='model.vision_backbone',
-    ))
+        aligner='model.vision_backbone.image_projector'))
 
 register_model_arch(
     MultiModelKeys(
```

swift/llm/model/register.py (6 additions, 6 deletions)

```diff
@@ -313,17 +313,17 @@ def _check_torch_dtype(torch_dtype: torch.dtype):
 
 def get_default_torch_dtype(torch_dtype: Optional[torch.dtype]):
     # torch_dtype: torch_dtype in config.json
+    if torch_dtype is not None:
+        return torch_dtype
+
     if is_torch_cuda_available() or is_torch_npu_available():
         if is_torch_bf16_gpu_available():
-            if torch_dtype in {torch.float16, torch.bfloat16}:
-                res = torch_dtype
-            else:
-                res = torch.bfloat16
+            return torch.bfloat16
         else:
-            res = torch.float16
+            return torch.float16
     else:
         # cpu
-        res = torch.float32
+        return torch.float32
     return res
 
 
```
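Restated outside the diff with the imports the hunk assumes (the availability helpers exist under `transformers.utils`; where ms-swift actually imports them from is not visible here), the resolution order is now: the config dtype if present, else bf16 on bf16-capable accelerators, else fp16 on GPU/NPU, else fp32 on CPU:

```python
from typing import Optional

import torch
from transformers.utils import (is_torch_bf16_gpu_available, is_torch_cuda_available,
                                is_torch_npu_available)


def get_default_torch_dtype(torch_dtype: Optional[torch.dtype]) -> torch.dtype:
    # torch_dtype: the value recorded in config.json; after this commit it
    # always wins when set, instead of being filtered against bf16 support.
    if torch_dtype is not None:
        return torch_dtype
    if is_torch_cuda_available() or is_torch_npu_available():
        return torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
    return torch.float32  # cpu
```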

swift/llm/template/base.py (3 additions, 0 deletions)

```diff
@@ -232,6 +232,9 @@ def decode(self, generate_ids: List[int], is_finished: bool = True, tokenizer_kw
         tokenizer_kwargs = tokenizer_kwargs or {}
         return self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
 
+    def generate(self, model, *args, **kwargs):
+        return model.generate(*args, **kwargs)
+
     def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
         # Do not print template_meta.suffix[-1] and eos_token.
         # However, other stop_words will be printed.
```
Molmo template (34 additions, 144 deletions)

```diff
@@ -1,168 +1,58 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 import torch
 
 from ..base import Template
 from ..constant import MLLMTemplateType
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
-from ..utils import findall
+from ..utils import Context, findall
 
 
 class MolmoTemplate(Template):
-    system = None
-    use_model = True
-    image_placeholder = ['<|image|>']
-    DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
-    DEFAULT_IM_START_TOKEN = '<im_start>'
-    DEFAULT_IM_END_TOKEN = '<im_end>'
-    DEFAULT_IM_COL_TOKEN = '<im_col>'
 
-    def __init__(self, *args, **kwargs):
-        Template.__init__(self, *args, **kwargs)
-        self.processor_kwargs = {
-            'images_kwargs': {
-                'max_crops': 12,
-                'overlap_margins': [4, 4],
-                'base_image_input_size': [336, 336],
-                'image_token_length_w': 12,
-                'image_token_length_h': 12,
-                'image_patch_size': 14,
-                'image_padding_mask': True,
-            },
-            'text_kwargs': {
-                'style': 'long_caption',
-                'system_prompt': 'none',
-                'message_format': 'role',
-                'always_start_with_space': True,
-                'sequence_length': 1536,
-                'padding': False,
-            }
-        }
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        return []
 
     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded = super()._encode(inputs)
         # image
-        raw_image = inputs.images
-        res = {}
+        images_inputs = self.processor.process(images=inputs.images or None, text='')
+        images_input_ids = images_inputs.pop('input_ids').tolist()
+        user_token = self._tokenize(' User')
+        assert len(user_token) == 1
+        idx = findall(images_input_ids, user_token[0])
+        assert len(idx) == 1
         labels = encoded['labels']
-        if raw_image:
-            image_id = self.tokenizer.convert_tokens_to_ids(self.image_placeholder)
-            idx_list = findall(encoded['input_ids'], image_id)
-            res = self._process_images(raw_image, encoded['input_ids'], idx_list, labels)
-            import numpy as np
-            if 'image_input_idx' in res:
-                # Shift patch mapping up by one since we added BOS
-                image_input_idx = res['image_input_idx']
-                res['image_input_idx'] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
-            encoded['input_ids'] = res.pop('input_ids').tolist()
-            if labels:
-                encoded['labels'] = [-100] + res.pop('labels')  # add one label for BOS
-
-            for k, v in res.items():
-                res[k] = torch.from_numpy(v).unsqueeze(0)
-        bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
-        encoded['input_ids'] = [bos] + encoded['input_ids']
-        res.update({'input_ids': encoded['input_ids']})
-        # prepare meta inputs
-        encoded.update(self.prepare_meta_inputs(res))
-
+        encoded['input_ids'] = images_input_ids[:idx[0]] + encoded['input_ids']
+        if labels:
+            encoded['labels'] = [-100] * idx[0] + labels
+        if 'images' in images_inputs:
+            images_inputs['images'] = images_inputs['images'].to(self.config.torch_dtype)
+        encoded.update(images_inputs)
         return encoded
 
-    def _process_images(self, images: List, tokens: List, idx_list: List = None, labels: List = None) -> torch.Tensor:
-        from PIL.Image import Image
-        import numpy as np
-        if images is not None:
-            image_arrays = []
-            for image in images:
-                if isinstance(image, Image):
-                    image = image.convert('RGB')
-                    image_arrays.append(np.array(image))
-                else:
-                    assert len(image.shape) == 3 and image.shape[-1] == 3
-                    image_arrays.append(image.astype(np.uint8))
-            images = image_arrays
-        # For now only support inserting images at the start
-        if idx_list is None:
-            idx_list = [-1] * len(images)
-        image_patch_token_id = self.processor.special_token_ids[self.DEFAULT_IMAGE_PATCH_TOKEN]
-        image_col_token_id = self.processor.special_token_ids[self.DEFAULT_IM_COL_TOKEN]
-        image_start_token_id = self.processor.special_token_ids[self.DEFAULT_IM_START_TOKEN]
-        image_end_token_id = self.processor.special_token_ids[self.DEFAULT_IM_END_TOKEN]
-        sequence_length = self.processor_kwargs['text_kwargs']['sequence_length']
-        res = self.processor.image_processor.multimodal_preprocess(
-            images=images,
-            image_idx=idx_list,
-            tokens=np.asarray(tokens).astype(np.int32),
-            sequence_length=sequence_length,
-            image_patch_token_id=image_patch_token_id,
-            image_col_token_id=image_col_token_id,
-            image_start_token_id=image_start_token_id,
-            image_end_token_id=image_end_token_id,
-            **self.processor_kwargs['images_kwargs'])
-        if labels is not None:
-            new_labels = []
-            cur_idx = 0
-            for input_id in res['input_ids']:
-                if input_id in (image_start_token_id, image_end_token_id, image_col_token_id, image_patch_token_id):
-                    new_labels.append(-100)
-                    if tokens[cur_idx] == self.tokenizer.convert_tokens_to_ids(self.image_placeholder)[0]:
-                        cur_idx += 1
-                else:
-                    new_labels.append(labels[cur_idx])
-                    cur_idx += 1
-            res['labels'] = new_labels
-        return res
-
-    def prepare_meta_inputs(self, data: Any) -> Dict[str, Any]:
-
-        # prepare batch inputs
-        input_ids = torch.tensor(data['input_ids']).unsqueeze(0)
-        batch_size, seq_len = input_ids.shape
-        attention_mask = None
-        mask_len = seq_len
-        max_new_tokens = None
-        if not self.is_training:
-            generation_config = self.model.generation_config
-            max_new_tokens = generation_config.max_new_tokens
-            if not max_new_tokens:
-                max_new_tokens = 0
-            mask_len = mask_len + max_new_tokens if self.model.config.use_position_ids else mask_len
-        position_ids: Optional[torch.Tensor] = None
-        append_last_valid_logits: Optional[torch.Tensor] = None
-        if self.model.config.use_position_ids and attention_mask is None:
-            attention_mask = input_ids != -1
-            position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0)
-            append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
-            if max_new_tokens:
-                attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
-                    dim=1,
-                )
-        if attention_mask is not None:
-            assert attention_mask.shape == (batch_size, mask_len)
-        if self.is_training:
-            # no batch_size before data_collator
-            attention_mask = attention_mask.squeeze(0)
-            position_ids = position_ids.squeeze(0)
-        data.update({
-            'attention_mask': attention_mask,
-            'position_ids': position_ids,
-            'append_last_valid_logits': append_last_valid_logits,
-        })
-        if 'images' in data:
-            data['images'] = data['images'].to(self.model.dtype)
-        return data
+    def generate(self, model, **kwargs):
+        kwargs.pop('attention_mask', None)
+        generation_config = kwargs.pop('generation_config')
+        batch = {
+            k: kwargs.pop(k, None)
+            for k in ['input_ids', 'attention_mask', 'images', 'image_input_idx', 'image_masks']
+        }
+        return model.generate_from_batch(batch, generation_config, **kwargs)
 
     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to=padding_to)
+        res = super()._data_collator(batch, padding_to=padding_to)
         # prepare batchfy inputs
-        keys = ['images', 'image_input_idx', 'image_masks', 'append_last_valid_logits']
+        keys = ['images', 'image_input_idx', 'image_masks']
+        images_res = self.fetch_inputs(batch, keys)
         for key in keys:
-            batch_input = [b[key] for b in batch if b.get(key) is not None]
-            res[key] = torch.concat(batch_input)
-
+            val = images_res.get(key)
+            if val:
+                images_res[key] = torch.stack(val)
+        res.update(images_res)
         return res
 
 
@@ -171,8 +61,8 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         MLLMTemplateType.molmo,
         prefix=[],
         prompt=[' User: {{QUERY}} Assistant:'],
-        chat_sep=['<|endoftext|>'],
+        chat_sep=None,
         suffix=['<|endoftext|>'],
         template_cls=MolmoTemplate,
-        placeholder_tokens=['<|image|>'],
+        placeholder_tokens=['<im_patch>'],
     ))
```
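The rewritten `_encode` no longer re-implements Molmo's preprocessing: it lets the checkpoint's own processor build the image prefix (BOS plus image patch tokens) by processing the images with empty text, locates where the processor's ' User' token starts, and splices everything before it onto the front of the template-encoded ids, masking the spliced prefix out of the labels. With toy token ids (42 standing in for the ' User' token id):

```python
# Toy illustration of the splice in MolmoTemplate._encode; real ids come from
# the Molmo processor, and findall is ms-swift's subsequence search.
images_input_ids = [1, 900, 901, 902, 42, 11, 12]  # BOS, image tokens, ' User', ...
input_ids = [42, 11, 12, 13]                       # template-encoded conversation
labels = [-100, -100, -100, 13]

idx = images_input_ids.index(42)                   # findall(images_input_ids, user_token)[0]
input_ids = images_input_ids[:idx] + input_ids     # prepend BOS + image tokens
labels = [-100] * idx + labels                     # never train on the image prefix
print(input_ids)  # [1, 900, 901, 902, 42, 11, 12, 13]
print(labels)     # [-100, -100, -100, -100, -100, -100, -100, 13]
```

The collator change is the batching counterpart: encoded samples now carry per-sample tensors without a batch dimension, so the collator stacks them with `torch.stack` instead of concatenating tensors that already had one.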

swift/llm/template/template/mplug.py (1 addition, 1 deletion)

```diff
@@ -43,7 +43,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         return res
 
     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
-        res = super().data_collator(batch, padding_to=padding_to)
+        res = super()._data_collator(batch, padding_to=padding_to)
         images = [b['images'] for b in batch if 'images' in b]
         if images:
             res['images'] = torch.concat(images)
```

tests/test_align/test_template/test_vision.py (31 additions, 2 deletions)

```diff
@@ -157,7 +157,9 @@ def test_deepseek_vl2():
 
 
 def test_mplug_owl2():
-    pass
+    # pt_engine = PtEngine('iic/mPLUG-Owl2')
+    pt_engine = PtEngine('iic/mPLUG-Owl2.1')
+    _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>这是什么'}])
 
 
 def test_mplug_owl3():
@@ -218,6 +220,30 @@ def test_megrez_omni():
         '没有阴影或明亮的阳光表明这不是正午时分,也没有雨滴或雪花的迹象,这可能意味着不是下雨或下雪的日子。')
 
 
+def test_molmo():
+    # pt_engine = PtEngine('LLM-Research/Molmo-7B-O-0924')
+    pt_engine = PtEngine('LLM-Research/Molmo-7B-D-0924')
+    _infer_model(pt_engine)
+    response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>这是什么'}])
+    assert response == (
+        ' This is a close-up photograph of a young kitten. '
+        'The kitten has striking blue eyes and a mix of white and black fur, '
+        'with distinctive black stripes on its head and face. '
+        "It's looking directly at the camera with an alert and curious expression. "
+        "The kitten's fur appears soft and fluffy, and its pink nose and white whiskers are clearly visible. "
+        'The background is blurred, which emphasizes the kitten as the main subject of the image.')
+
+
+def test_molmoe():
+    pt_engine = PtEngine('LLM-Research/MolmoE-1B-0924')
+    response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<image>这是什么'}])
+    assert response == (" This is a close-up photograph of a kitten's face. The kitten has striking blue eyes and "
+                        "a mix of white, black, and brown fur. It's looking directly at the camera with an adorable "
+                        "expression, its ears perked up and whiskers visible. The image captures the kitten's cute "
+                        'features in sharp detail, while the background is blurred, creating a soft, out-of-focus '
+                        "effect that emphasizes the young feline's charm.")
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig, get_template
     from swift.utils import get_logger, seed_everything
@@ -247,4 +273,7 @@ def test_megrez_omni():
     # test_mplug_owl3()
     # test_xcomposer2_5()
     # test_megrez_omni()
-    test_qvq()
+    # test_qvq()
+    # test_mplug_owl2()
+    # test_molmo()
+    test_molmoe()
```
