Skip to content

Commit 24e6b11

Browse files
authored
fix device_map 4 (qwen-vl) (#695)
1 parent 1dc2d65 commit 24e6b11

File tree

3 files changed

+59
-6
lines changed

3 files changed

+59
-6
lines changed

docs/source/Multi-Modal/minicpm-v最佳实践.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
# MiniCPM-V 最佳实践
3+
以下内容以`minicpm-v-3b-chat`为例, 如果你想要使用更新版本的 MiniCPM-V 多模态模型(v2), 你可以将`--model_type minicpm-v-3b-chat`切换成`--model_type minicpm-v-v2`.
34

45
## 目录
56
- [环境准备](#环境准备)
@@ -13,9 +14,14 @@
1314
pip install ms-swift[llm] -U
1415
```
1516

17+
模型链接:
18+
- minicpm-v-3b-chat: [https://modelscope.cn/models/OpenBMB/MiniCPM-V/summary](https://modelscope.cn/models/OpenBMB/MiniCPM-V/summary)
19+
- minicpm-v-v2: [https://modelscope.cn/models/OpenBMB/MiniCPM-V-2.0/summary](https://modelscope.cn/models/OpenBMB/MiniCPM-V-2.0/summary)
20+
21+
1622
## 推理
1723

18-
推理[minicpm-v-3b-chat](https://modelscope.cn/models/OpenBMB/MiniCPM-V/summary):
24+
推理minicpm-v-3b-chat:
1925
```shell
2026
# Experimental environment: A10, 3090, V100, ...
2127
# 10GB GPU memory

swift/llm/utils/model.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from transformers.utils.versions import require_version
2727

2828
from swift import get_logger
29-
from swift.utils import is_dist, is_local_master, use_torchacc
29+
from swift.utils import (get_dist_setting, is_dist, is_local_master,
30+
use_torchacc)
3031
from .template import TemplateType
3132
from .utils import get_max_model_len
3233

@@ -2206,7 +2207,27 @@ def get_model_tokenizer_qwen_chat(*args, **kwargs):
22062207
return model, tokenizer
22072208

22082209

2210+
def _qwen_vl_visual_block_forward(
2211+
self,
2212+
q_x: torch.Tensor,
2213+
k_x: Optional[torch.Tensor] = None,
2214+
v_x: Optional[torch.Tensor] = None,
2215+
attn_mask: Optional[torch.Tensor] = None,
2216+
):
2217+
k_x = self.ln_1_kv(k_x) if hasattr(self,
2218+
'ln_1_kv') and k_x is not None else None
2219+
v_x = self.ln_1_kv(v_x) if hasattr(self,
2220+
'ln_1_kv') and v_x is not None else None
2221+
2222+
x = q_x + self.attention(
2223+
q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
2224+
z = self.mlp(self.ln_2(x))
2225+
x = x.to(z.device) + z # FIX
2226+
return x
2227+
2228+
22092229
def fix_qwen_inplace_bug(model) -> None:
2230+
# qwen-vl, qwen-audio
22102231
first_drop = model.transformer.drop
22112232
if first_drop.p == 0.:
22122233
# fix in-place operation bug
@@ -2271,12 +2292,27 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
22712292
if not hasattr(tokenizer_cls, '_old_decode'): # avoid double patching
22722293
tokenizer_cls._old_decode = tokenizer_cls._decode
22732294
tokenizer_cls._decode = _qwen_vl_audio_decode
2295+
# fix device_map is 4
2296+
n_gpu = torch.cuda.device_count()
2297+
local_world_size = get_dist_setting()[3]
2298+
if n_gpu // local_world_size >= 4:
2299+
visual_block_cls = get_class_from_dynamic_module(
2300+
'visual.VisualAttentionBlock', model_dir)
2301+
if not hasattr(visual_block_cls,
2302+
'__old_forward'): # avoid double patching
2303+
visual_block_cls.__old_forward = visual_block_cls.forward
2304+
visual_block_cls.forward = _qwen_vl_visual_block_forward
2305+
22742306
kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
22752307
model_dir, trust_remote_code=True)
22762308
model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
22772309
load_model, **kwargs)
22782310
if model is not None:
22792311
fix_qwen_inplace_bug(model)
2312+
# fix device_map is 4
2313+
if n_gpu // local_world_size >= 4:
2314+
model.transformer.visual.proj.data = model.transformer.visual.proj.to(
2315+
model.transformer.visual.ln_post.bias.device)
22802316

22812317
return model, tokenizer
22822318

swift/llm/utils/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,23 @@ def find_all_linears(model: Module, quantization_bit: int,
399399
if 'aqlm' in model_type:
400400
from aqlm import QuantizedLinear
401401
linear_cls.append(QuantizedLinear)
402+
403+
# The content of target_module_names cannot exist in inner_nodes.
404+
# O(n^2logn), n represents the number of nodes, n<1000.
405+
inner_nodes = set()
406+
for name, module in model.named_modules():
407+
if not isinstance(module, tuple(linear_cls)):
408+
inner_nodes.add(name)
402409
target_module_names = set()
403410
for name, module in model.named_modules():
404-
if isinstance(module, tuple(linear_cls)):
405-
module_name = '.'.join(name.split('.')[-2:])
406-
if head_module_name not in module_name:
407-
target_module_names.add(module_name)
411+
if isinstance(module,
412+
tuple(linear_cls)) and head_module_name not in name:
413+
module_name_list = name.split('.')
414+
module_name = module_name_list.pop()
415+
for inner_node in inner_nodes:
416+
while inner_node.endswith(module_name):
417+
module_name = f'{module_name_list.pop()}.{module_name}'
418+
target_module_names.add(module_name)
408419
return list(target_module_names)
409420

410421

0 commit comments

Comments
 (0)