
Commit 1156a57

support mPLUG-Owl3 241101 (#2515)

1 parent 80db4a6 commit 1156a57

File tree: 2 files changed, +73 -0 lines changed

swift/llm/utils/model.py

Lines changed: 10 additions & 0 deletions

@@ -563,6 +563,7 @@ class ModelType:
     mplug_owl3_1b_chat = 'mplug-owl3-1b-chat'
     mplug_owl3_2b_chat = 'mplug-owl3-2b-chat'
     mplug_owl3_7b_chat = 'mplug-owl3-7b-chat'
+    mplug_owl3v_7b_chat = 'mplug-owl3v-7b-chat'
     # yuan
     yuan2_2b_instruct = 'yuan2-2b-instruct'
     yuan2_2b_janus_instruct = 'yuan2-2b-janus-instruct'
@@ -3070,6 +3071,15 @@ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx
     support_flash_attn=True,
     tags=['multi-modal', 'vision', 'video'],
     hf_model_id='mPLUG/mPLUG-Owl3-7B-240728')
+@register_model(
+    ModelType.mplug_owl3v_7b_chat,
+    'iic/mPLUG-Owl3-7B-241101',
+    LoRATM.mplug_owl3,
+    TemplateType.mplug_owl3v,
+    requires=['transformers>=4.36', 'icecream'],  # decord
+    support_flash_attn=True,
+    tags=['multi-modal', 'vision', 'video'],
+    hf_model_id='mPLUG/mPLUG-Owl3-7B-241101')
 def get_model_tokenizer_mplug_owl3(model_dir: str,
                                    torch_dtype: torch.dtype,
                                    model_kwargs: Dict[str, Any],
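With the registration above, the 241101 checkpoint plugs into swift's standard loading path. A minimal inference sketch, assuming the ms-swift 2.x Python API of this era (get_model_tokenizer, get_template, inference from swift.llm); the image path and query below are placeholders, not part of this commit:

# Minimal sketch: load and query the newly registered model.
# Assumes the ms-swift 2.x API; arguments and kwargs are illustrative.
import torch
from swift.llm import (ModelType, get_default_template_type,
                       get_model_tokenizer, get_template, inference)

model_type = ModelType.mplug_owl3v_7b_chat
template_type = get_default_template_type(model_type)  # should resolve to 'mplug_owl3v'

model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16,
                                       model_kwargs={'device_map': 'auto'})
template = get_template(template_type, tokenizer)

images = ['<path-or-url-to-image>']  # placeholder
query = '<image>What is in the picture?'
response, history = inference(model, template, query, images=images)
print(response)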

swift/llm/utils/template.py

Lines changed: 63 additions & 0 deletions

@@ -134,6 +134,7 @@ class TemplateType:
     paligemma = 'paligemma'
     mplug_owl2 = 'mplug-owl2'
     mplug_owl3 = 'mplug_owl3'
+    mplug_owl3v = 'mplug_owl3v'
     wizardlm2_awq = 'wizardlm2-awq'
     wizardlm2 = 'wizardlm2'
     atom = 'atom'
@@ -4004,7 +4005,69 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
         return res


+class mPlugOwl3vTemplate(mPlugOwl3Template):
+    system = None
+
+    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        inputs, _ = super(mPlugOwl3Template, self)._encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
+        images = example['images']
+        videos = example['videos']
+        cut_enable = not videos
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+        idx_list = _findall(input_ids, -100)
+        processor = self.tokenizer.processor
+        inputs = {'_data': {}}
+        if images:
+            image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt')
+            added_tokens_len = 0
+            cut_shapes = image_inputs['cut_shape'] or [None] * 2 * len(idx_list)
+            image_token_list = self.tokenizer.encode('<|image|>', add_special_tokens=False)
+            for idx, cut_shape in zip(idx_list, cut_shapes[::2]):
+                if cut_shape:
+                    token_list = self._get_image_token_list(cut_shape)
+                else:
+                    token_list = image_token_list
+                input_ids = input_ids[:idx + added_tokens_len] + token_list + input_ids[added_tokens_len + idx + 1:]
+                if labels:
+                    labels = labels[:idx + added_tokens_len] + [-100] * len(token_list) + labels[added_tokens_len + idx
+                                                                                                 + 1:]
+                added_tokens_len += len(token_list) - 1
+            image_token_idx = torch.tensor(_findall(input_ids, image_token_list))
+
+            inputs['_data'].update({
+                'pixel_values': image_inputs['pixel_values'],
+                'media_offset': image_token_idx,
+            })
+        inputs['_data']['input_ids'] = input_ids
+        inputs['labels'] = labels
+        return inputs, {}
+
+    def _post_encode(self, model, data: Any) -> Dict[str, Any]:
+        if 'pixel_values' in data:
+            pixel_values = data.pop('pixel_values')
+            data['image_embeds'] = model.forward_image(pixel_values)
+        return data
+
+    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super(mPlugOwl3Template, self).data_collator(batch, padding_to)
+        image_embeds = [b['image_embeds'] for b in batch if 'image_embeds' in b]
+        if image_embeds:
+            res['image_embeds'] = torch.concat(image_embeds)
+        media_offset = []
+
+        for bi, b in enumerate(batch):
+            media_offset.append(b.get('media_offset', torch.tensor([]).long()))
+
+        if media_offset:
+            res['media_offset'] = media_offset
+        return res
+
+
 register_template(TemplateType.mplug_owl3, mPlugOwl3Template(), use_model=True, lazy_tokenize=True)
+register_template(TemplateType.mplug_owl3v, mPlugOwl3vTemplate(), use_model=True, lazy_tokenize=True)

 register_template(TemplateType.wizardlm2_awq,
                   Template(['{{SYSTEM}}'], ['User:\n{{QUERY}}\n\nAssistant:\n'], ['\n\n'], ['</s>']))
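A note on the template logic above: _encode locates every -100 image placeholder in input_ids and splices in a variable-length run of <|image|> tokens (the length depends on the image's cut shape), with added_tokens_len tracking the cumulative growth so later placeholder indices stay valid; labels gets the same splice but filled with -100 so image tokens contribute no loss. A self-contained toy sketch of that splice pattern (all token ids are made up; this is not swift code):

# Toy illustration of the placeholder-splice pattern used in _encode.
# Token ids are invented; the real code derives token_lists from image cut shapes.
IMAGE_PLACEHOLDER = -100
input_ids = [1, IMAGE_PLACEHOLDER, 2, IMAGE_PLACEHOLDER, 3]
token_lists = [[10, 11, 12], [10, 11]]  # per-image expansions of varying length

idx_list = [i for i, t in enumerate(input_ids) if t == IMAGE_PLACEHOLDER]
added_tokens_len = 0
for idx, token_list in zip(idx_list, token_lists):
    pos = idx + added_tokens_len
    # Replace one placeholder token with len(token_list) image tokens.
    input_ids = input_ids[:pos] + token_list + input_ids[pos + 1:]
    # Every later placeholder shifts right by len(token_list) - 1.
    added_tokens_len += len(token_list) - 1

print(input_ids)  # [1, 10, 11, 12, 2, 10, 11, 3]

The media_offset positions gathered after the splice record where the image tokens sit in the final sequence, which is how the collated batch tells the model where each image's embeddings attach.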
