Skip to content

Commit f0edb83

Browse files
authored
fix xcomposer device_map (#844)
1 parent 2accb9b commit f0edb83

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

swift/llm/utils/model.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,7 @@ def get_model_tokenizer_mamba(model_dir: str,
938938
LoRATM.cogagent,
939939
TemplateType.cogagent_chat,
940940
support_gradient_checkpointing=False,
941+
requires=['timm'],
941942
tags=['multi-modal', 'vision'],
942943
hf_model_id='THUDM/cogagent-chat-hf')
943944
@register_model(
@@ -946,6 +947,7 @@ def get_model_tokenizer_mamba(model_dir: str,
946947
LoRATM.cogagent,
947948
TemplateType.cogagent_instruct,
948949
support_gradient_checkpointing=False,
950+
requires=['timm'],
949951
tags=['multi-modal', 'vision'],
950952
hf_model_id='THUDM/cogagent-vqa-hf')
951953
def get_model_tokenizer_cogagent(model_dir: str,
@@ -2610,9 +2612,31 @@ def get_model_tokenizer_internlm_xcomposer2(model_dir: str,
26102612
if getattr(tokenizer.__class__.eos_token_id, 'fset', None) is None:
26112613
del tokenizer.__class__.eos_token_id
26122614
tokenizer.eos_token = eos_token
2613-
if model is not None and use_flash_attn:
2614-
# fix AttributeError: no attribute 'attention_dropout'
2615-
model.model.layers[0].attention.__class__.attention_dropout = 0.
2615+
if model is not None:
2616+
if use_flash_attn:
2617+
# fix AttributeError: no attribute 'attention_dropout'
2618+
model.model.layers[0].attention.__class__.attention_dropout = 0.
2619+
2620+
model_cls = model.__class__
2621+
if not hasattr(model_cls, '__old_encode_img'): # avoid double patching
2622+
model_cls.__old_encode_img = model_cls.encode_img
2623+
2624+
def _new_encode_img(self, image):
2625+
if image is None:
2626+
return None
2627+
if isinstance(image, str):
2628+
from PIL import Image
2629+
image = Image.open(image).convert('RGB')
2630+
image = self.vis_processor(image).unsqueeze(0).to(
2631+
self.device)
2632+
else:
2633+
assert isinstance(image, torch.Tensor)
2634+
2635+
img_embeds, atts_img, img_target = self.img2emb(image)
2636+
return img_embeds.to(device=self.device) # FIX device_map
2637+
2638+
model_cls.encode_img = _new_encode_img
2639+
26162640
return model, tokenizer
26172641

26182642

0 commit comments

Comments (0)