
Commit c222140

fix minicpm-v (#1562)
1 parent: bd3ef2d


2 files changed (+48 -41 lines)


swift/llm/utils/model.py

Lines changed: 33 additions & 23 deletions
@@ -5572,6 +5572,7 @@ def _new_forward(*args, **kwargs) -> Tensor:
     LoRATM.llama,
     TemplateType.minicpm_v,
     support_flash_attn=True,
+    requires=['timm', 'transformers<4.42'],
     tags=['multi-modal', 'vision'],
     hf_model_id='openbmb/MiniCPM-V')
 @register_model(
@@ -5580,44 +5581,53 @@ def _new_forward(*args, **kwargs) -> Tensor:
     LoRATM.llama,
     TemplateType.minicpm_v,
     support_flash_attn=True,
-    requires=['timm'],
+    requires=['timm', 'transformers<4.42'],
     tags=['multi-modal', 'vision'],
     hf_model_id='openbmb/MiniCPM-V-2')
-@register_model(
-    ModelType.minicpm_v_v2_5_chat,
-    'OpenBMB/MiniCPM-Llama3-V-2_5',
-    LoRATM.minicpm_llama,
-    TemplateType.minicpm_v_v2_5,
-    support_flash_attn=True,
-    support_lmdeploy=True,
-    requires=['timm'],
-    placeholder_tokens=['<unk>'],
-    function_kwargs={'patching_embedding': True},
-    tags=['multi-modal', 'vision'],
-    hf_model_id='openbmb/MiniCPM-Llama3-V-2_5')
 def get_model_tokenizer_minicpm_v(model_dir: str,
                                   torch_dtype: Dtype,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
-    patching_embedding = kwargs.pop('patching_embedding', False)
     model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
     if load_model:
         model.resampler.to(torch_dtype)  # fix float32
         _patch_minicpm_v_device_map(model)
         func_list = ['generate', 'get_input_embeddings', 'forward']
         _use_submodel_func(model, 'llm', func_list)
-        if patching_embedding:
-            embedding = model.get_input_embeddings()
-            if not hasattr(embedding, '__old_forward'):  # Avoid double patching
-                old_forward = embedding.forward
+    return model, tokenizer
+
+
+@register_model(
+    ModelType.minicpm_v_v2_5_chat,
+    'OpenBMB/MiniCPM-Llama3-V-2_5',
+    LoRATM.minicpm_llama,
+    TemplateType.minicpm_v_v2_5,
+    support_flash_attn=True,
+    requires=['timm', 'transformers>=4.36'],
+    placeholder_tokens=['<unk>'],
+    tags=['multi-modal', 'vision'],
+    hf_model_id='openbmb/MiniCPM-Llama3-V-2_5')
+def get_model_tokenizer_minicpm_v_2_5(model_dir: str,
+                                      torch_dtype: Dtype,
+                                      model_kwargs: Dict[str, Any],
+                                      load_model: bool = True,
+                                      **kwargs):
+    from transformers import AutoProcessor
+    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
+    model, tokenizer = get_model_tokenizer_minicpm_v(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
+    tokenizer.processor = processor
+    if load_model:
+        embedding = model.get_input_embeddings()
+        if not hasattr(embedding, '__old_forward'):  # Avoid double patching
+            old_forward = embedding.forward
 
-                @wraps(old_forward)
-                def _new_forward(*args, **kwargs):
-                    return old_forward(*args, **kwargs).requires_grad_(True).clone()
+            @wraps(old_forward)
+            def _new_forward(*args, **kwargs):
+                return old_forward(*args, **kwargs).requires_grad_(True).clone()
 
-                embedding.__old_forward = old_forward
-                embedding.forward = _new_forward
+            embedding.__old_forward = old_forward
+            embedding.forward = _new_forward
     return model, tokenizer
 
 
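In the new get_model_tokenizer_minicpm_v_2_5 above, the input embedding's forward is wrapped so that its output is cloned with requires_grad enabled, which is the usual workaround for keeping gradient checkpointing functional when the embedding weights themselves are frozen (the same idea as transformers' enable_input_require_grads). Below is a minimal, self-contained sketch of that patch in isolation; the plain nn.Embedding is a hypothetical stand-in for model.get_input_embeddings(), not part of the commit.

from functools import wraps

import torch
from torch import nn

# Hypothetical stand-in for model.get_input_embeddings(); the commit patches
# the embedding of the real MiniCPM-Llama3-V-2_5 model instead.
embedding = nn.Embedding(10, 4)
embedding.weight.requires_grad_(False)  # frozen, as in adapter-only training

if not hasattr(embedding, '__old_forward'):  # avoid double patching
    old_forward = embedding.forward

    @wraps(old_forward)
    def _new_forward(*args, **kwargs):
        # The clone carries a grad_fn, so downstream (checkpointed) blocks see
        # an input that requires grad even though the embedding weights do not.
        return old_forward(*args, **kwargs).requires_grad_(True).clone()

    embedding.__old_forward = old_forward
    embedding.forward = _new_forward

out = embedding(torch.tensor([1, 2, 3]))
print(out.requires_grad)  # True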

swift/llm/utils/template.py

Lines changed: 15 additions & 18 deletions
@@ -1878,10 +1878,10 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         if len(inputs) == 0:
             return inputs, {}
         images = example.get('images', [])
-        image_processor = self.tokenizer.processor.image_processor
         if self._is_vllm:
             images = self._prepare_vllm_images(images)
         if images:
+            image_processor = self.tokenizer.processor.image_processor
             image_inputs = image_processor(images, return_tensors='pt').to(self.model.dtype)
             inputs['pixel_values'] = image_inputs['pixel_values']
             if 'image_sizes' in image_inputs:
@@ -2470,7 +2470,16 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         tgt_sizes = None
         slice_mode = getattr(config, 'slice_mode', False)
         if slice_mode:
-            images, placeholder = self.model.get_slice_image_placeholder(image, self.tokenizer)
+            if self.is_v2_5:
+                from .utils import to_device
+                image_processor = self.tokenizer.processor.image_processor
+                image_inputs = image_processor(images, return_tensors='pt').to(self.model.dtype)
+                placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0])
+                pixel_values = to_device(image_inputs['pixel_values'], self.model.device)
+                tgt_sizes = image_inputs['tgt_sizes']
+            else:
+                images, placeholder = self.model.get_slice_image_placeholder(image, self.tokenizer)
+                pixel_values = [[self.model.transform(img).to(device=self.model.device) for img in images]]
             placeholder += '\n'
             placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
             input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
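
The v2.5 branch above now delegates image slicing to the checkpoint's Hugging Face image processor instead of self.model.get_slice_image_placeholder. The sketch below is a hedged usage example, not part of the commit: it only mirrors the calls visible in the diff, it downloads openbmb/MiniCPM-Llama3-V-2_5 and trusts its remote code, and the dummy image size is an arbitrary assumption.

from PIL import Image
from transformers import AutoProcessor

# Same loading call that get_model_tokenizer_minicpm_v_2_5 uses in model.py above.
processor = AutoProcessor.from_pretrained('openbmb/MiniCPM-Llama3-V-2_5', trust_remote_code=True)
image_processor = processor.image_processor

images = [Image.new('RGB', (1024, 768), color='white')]  # dummy image, arbitrary size
image_inputs = image_processor(images, return_tensors='pt')

# Fields the template consumes: pixel_values (slice tensors), image_sizes,
# and tgt_sizes (per-slice patch-grid sizes, v2.5 only).
placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0])
print(sorted(image_inputs.keys()), placeholder[:40])
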
@@ -2485,33 +2494,21 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
                 torch.hstack(
                     [image_start_idx[:valid_image_nums].unsqueeze(-1), image_end_idx[:valid_image_nums].unsqueeze(-1)])
             ]
-            if self.is_v2_5:
-                pixel_values = []
-                tgt_sizes = []
-                config = self.model.config
-                for image in images:
-                    image = self.model.transform(image).to(device=self.model.device)
-                    H, W = image.shape[1:]
-                    pixel_values.append(self.model.reshape_by_patch(image))
-                    tgt_sizes.append(torch.Tensor([H // config.patch_size, W // config.patch_size]).type(torch.int32))
-                tgt_sizes = torch.vstack(tgt_sizes)
-            else:
-                pixel_values = [self.model.transform(img).to(device=self.model.device) for img in images]
         else:
             placeholder = '<image>' + '<unk>' * config.query_num + '</image>\n'
             placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
             input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
             if labels is not None:
                 labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
             image_bound = [torch.tensor([[idx, idx + config.query_num]])]
-            pixel_values = [self.model.transform(image).to(device=self.model.device)]
+            pixel_values = [[self.model.transform(image).to(device=self.model.device)]]
         data = {
             'input_ids': torch.tensor(input_ids)[None].to(device=self.model.device),
             'image_bound': image_bound,
-            'pixel_values': [pixel_values]
+            'pixel_values': pixel_values
         }
-        if tgt_sizes is not None:
-            data['tgt_sizes'] = [tgt_sizes]
+        if tgt_sizes is not None:  # v2.5
+            data['tgt_sizes'] = tgt_sizes
         inputs_embeds, _ = self.model.get_vllm_embedding(data)
         inputs_embeds = inputs_embeds.detach()
         inputs['input_ids'] = input_ids
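
A possible smoke test for the re-registered model type, assuming the public swift.llm helpers (get_model_tokenizer, ModelType) keep their usual signature; the load_model=False check is my assumption and not part of the commit. It still downloads the checkpoint files and needs timm and transformers>=4.36 per the registration above.

# Hypothetical smoke test; not part of the commit.
from swift.llm import ModelType, get_model_tokenizer

# Even with load_model=False, get_model_tokenizer_minicpm_v_2_5 runs, loads the
# AutoProcessor and attaches it to the tokenizer for the template code to use.
model, tokenizer = get_model_tokenizer(ModelType.minicpm_v_v2_5_chat, load_model=False)
print(hasattr(tokenizer, 'processor'))  # expected: True
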
