Skip to content

Commit 54233a2

Browse files
authored
Fix minicpm device map (#978)
1 parent 8c841d4 commit 54233a2

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

docs/source/Multi-Modal/minicpm-v-2.5最佳实践.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
## 环境准备
1212
```shell
13-
pip install 'ms-swift[llm]' -U
13+
git clone https://github.com/modelscope/swift.git
14+
cd swift
15+
pip install -e '.[llm]'
1416
```
1517
模型链接:
1618
- minicpm-v-v2_5-chat: [https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary](https://modelscope.cn/models/OpenBMB/MiniCPM-Llama3-V-2_5/summary)

docs/source/Multi-Modal/minicpm-v-2最佳实践.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
## 环境准备
1212
```shell
13-
pip install 'ms-swift[llm]' -U
13+
git clone https://github.com/modelscope/swift.git
14+
cd swift
15+
pip install -e '.[llm]'
1416
```
1517

1618
## 推理

swift/llm/utils/model.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3870,6 +3870,52 @@ def get_model_tokenizer_yi_vl(model_dir: str,
38703870
return model, tokenizer
38713871

38723872

3873+
def _patch_minicpm_v_device_map(model) -> None:
3874+
if not hasattr(model, 'hf_device_map') or len(model.hf_device_map.values()) == 1:
3875+
return
3876+
if hasattr(model.llm, '__old_forward'):
3877+
# avoid double patching
3878+
return
3879+
device = list(model.hf_device_map.values())[0]
3880+
if hasattr(model, 'get_vision_embedding'): # minicpm-v-v2-chat
3881+
_old_get_vision_embedding = model.get_vision_embedding
3882+
3883+
def _get_vision_embedding(pixel_values):
3884+
if len(pixel_values) == 0:
3885+
return _old_get_vision_embedding(pixel_values)
3886+
output = _old_get_vision_embedding(pixel_values)
3887+
return output.to(device=device)
3888+
3889+
model._old_get_vision_embedding = _old_get_vision_embedding
3890+
model.get_vision_embedding = _get_vision_embedding
3891+
3892+
if hasattr(model, 'resampler'): # minicpm-v-v2_5-chat
3893+
__old_resampler_forward = model.resampler.forward
3894+
3895+
def _new_resampler_forward(*args, **kwargs) -> Tensor:
3896+
output = __old_resampler_forward(*args, **kwargs)
3897+
return output.to(device=device)
3898+
3899+
model.resampler.forward = _new_resampler_forward
3900+
3901+
__old_forward = model.llm.forward
3902+
3903+
def _new_forward(*args, **kwargs) -> Tensor:
3904+
inputs = kwargs.get('inputs_embeds')
3905+
if inputs is None:
3906+
inputs = kwargs.get('input_ids')
3907+
device = inputs.device
3908+
output = __old_forward(*args, **kwargs)
3909+
if output.logits is not None:
3910+
output.logits = output.logits.to(device)
3911+
if output.loss is not None:
3912+
output.loss = output.loss.to(device)
3913+
return output
3914+
3915+
model.llm.forward = _new_forward
3916+
model.llm.__old_forward = __old_forward
3917+
3918+
38733919
@register_model(
38743920
ModelType.minicpm_v_3b_chat,
38753921
'OpenBMB/MiniCPM-V',
@@ -3904,6 +3950,7 @@ def get_model_tokenizer_minicpm_v(model_dir: str,
39043950
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
39053951
if load_model:
39063952
model.resampler.to(torch_dtype) # fix float32
3953+
_patch_minicpm_v_device_map(model)
39073954
func_list = ['generate', 'get_input_embeddings', 'forward']
39083955
_use_submodel_func(model, 'llm', func_list)
39093956
if patching_embedding:

0 commit comments

Comments
 (0)