@@ -26,7 +26,8 @@
 from transformers.utils.versions import require_version
 
 from swift import get_logger
-from swift.utils import is_dist, is_local_master, use_torchacc
+from swift.utils import (get_dist_setting, is_dist, is_local_master,
+                         use_torchacc)
 from .template import TemplateType
 from .utils import get_max_model_len
 
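The new `get_dist_setting` import feeds the GPU-sharding check added further down. As a point of reference, here is a minimal sketch of how it is consumed, assuming (consistently with the `get_dist_setting()[3]` access later in this patch) that it returns a `(rank, local_rank, world_size, local_world_size)` tuple:

```python
# Sketch only: the tuple layout is an assumption inferred from the
# get_dist_setting()[3] access used later in this patch.
import torch

from swift.utils import get_dist_setting

rank, local_rank, world_size, local_world_size = get_dist_setting()
# GPUs visible to each local process; >= 4 triggers the visual-block patch below
gpus_per_process = torch.cuda.device_count() // local_world_size
```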
@@ -2206,7 +2207,27 @@ def get_model_tokenizer_qwen_chat(*args, **kwargs):
     return model, tokenizer
 
 
+def _qwen_vl_visual_block_forward(
+    self,
+    q_x: torch.Tensor,
+    k_x: Optional[torch.Tensor] = None,
+    v_x: Optional[torch.Tensor] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+):
+    k_x = self.ln_1_kv(k_x) if hasattr(self,
+                                       'ln_1_kv') and k_x is not None else None
+    v_x = self.ln_1_kv(v_x) if hasattr(self,
+                                       'ln_1_kv') and v_x is not None else None
+
+    x = q_x + self.attention(
+        q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
+    z = self.mlp(self.ln_2(x))
+    x = x.to(z.device) + z  # FIX: align devices when the block is sharded
+    return x
+
+
 def fix_qwen_inplace_bug(model) -> None:
+    # applies to qwen-vl and qwen-audio
     first_drop = model.transformer.drop
     if first_drop.p == 0.:
         # fix in-place operation bug
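The one-line `x.to(z.device) + z` change is the heart of the fix: when `device_map` shards the visual tower over several GPUs, `ln_2`/`mlp` can be placed on a later device than the block input, and the plain residual add fails. A hypothetical two-GPU repro (not from the patch, just illustrating the failure mode):

```python
# Requires >= 2 CUDA devices; mimics a residual add across a device_map split.
import torch

if torch.cuda.device_count() >= 2:
    x = torch.randn(4, 8, device='cuda:0')  # residual stream on the first GPU
    z = torch.randn(4, 8, device='cuda:1')  # MLP output placed on the next GPU
    try:
        _ = x + z  # RuntimeError: tensors expected on the same device
    except RuntimeError as err:
        print(err)
    out = x.to(z.device) + z  # the patched form: move x before adding
```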
@@ -2271,12 +2292,27 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
     if not hasattr(tokenizer_cls, '_old_decode'):  # avoid double patching
         tokenizer_cls._old_decode = tokenizer_cls._decode
         tokenizer_cls._decode = _qwen_vl_audio_decode
+    # fix the forward pass when device_map spreads a rank over >= 4 GPUs
+    n_gpu = torch.cuda.device_count()
+    local_world_size = get_dist_setting()[3]
+    if n_gpu // local_world_size >= 4:
+        visual_block_cls = get_class_from_dynamic_module(
+            'visual.VisualAttentionBlock', model_dir)
+        if not hasattr(visual_block_cls,
+                       '__old_forward'):  # avoid double patching
+            visual_block_cls.__old_forward = visual_block_cls.forward
+            visual_block_cls.forward = _qwen_vl_visual_block_forward
+
     kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
         model_dir, trust_remote_code=True)
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
         fix_qwen_inplace_bug(model)
+        # keep visual.proj on the same device as the ln_post output
+        if n_gpu // local_world_size >= 4:
+            model.transformer.visual.proj.data = model.transformer.visual.proj.to(
+                model.transformer.visual.ln_post.bias.device)
 
     return model, tokenizer
 
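The final hunk also rebinds `visual.proj` onto the device of `ln_post`. Qwen-VL's visual encoder ends, roughly, with `x = self.ln_post(x)` followed by `x = x @ self.proj`, so under a sharded `device_map` the projection weight must be co-located with the `ln_post` output. A hedged sketch of the same move as a standalone helper (names mirror the patch; the encoder-tail shape is an assumption about Qwen-VL's remote code):

```python
import torch


def align_visual_proj(model: torch.nn.Module) -> None:
    # Assumes the Qwen-VL layout used in the patch:
    # model.transformer.visual.{proj, ln_post}.
    visual = model.transformer.visual
    target = visual.ln_post.bias.device  # device producing the ln_post output
    # Reassign .data in place so existing references to the Parameter
    # (e.g. from accelerate hooks) keep pointing at the same object.
    visual.proj.data = visual.proj.to(target)
```

Rebinding `.data`, as the patch does, keeps the `Parameter` object identity stable rather than constructing a new parameter.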