
Commit a2459b8

update get_model_tokenizer_with_flash_attn (#3337)
1 parent 463091c commit a2459b8

File tree: 7 files changed (+88, -20 lines)

docs/source/Customization/自定义数据集.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -121,7 +121,7 @@ label代表两个句子的相似度, loss使用`cosine_similarity`
 {"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "<image>图片中是什么,<video>视频中是什么"}, {"role": "assistant", "content": "图片中是一个大象,视频中是一只小狗在草地上奔跑"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
 ```
 
-多模态模型的RLHF和序列分类的数据格式可以参考纯文本大模型的格式。
+多模态模型的RLHF和序列分类的数据格式可以参考纯文本大模型的格式,并在此基础上增加`images`等字段
 
 #### grounding
 
````

docs/source_en/Customization/Custom-dataset.md

Lines changed: 1 addition & 2 deletions

````diff
@@ -123,8 +123,7 @@ Supervised Fine-tuning:
 {"messages": [{"role": "system", "content": "You are a helpful and harmless assistant."}, {"role": "user", "content": "<image>What is in the image, <video>What is in the video?"}, {"role": "assistant", "content": "The image shows an elephant, and the video shows a puppy running on the grass."}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
 ```
 
-The data formats for RLHF and sequence classification in multimodal models can refer to the formats used in pure text large models.
-
+The data format for RLHF and sequence classification of multimodal models can reference the format of pure text large models, with additional fields such as `images` added on top of that.
 
 #### Grounding
 
````
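For illustration, a hedged sketch of what the amended sentence describes. The `rejected_response` and `label` fields below follow the pure-text conventions documented elsewhere in these docs, not this diff, and the paths are placeholders:

```python
# Hypothetical samples (not part of the commit): pure-text RLHF and
# sequence-classification formats with an 'images' field added on top.
# 'rejected_response' and 'label' are assumed from the pure-text examples.
rlhf_sample = {
    'messages': [{'role': 'user', 'content': '<image>What is in the image?'},
                 {'role': 'assistant', 'content': 'The image shows an elephant.'}],
    'rejected_response': 'The image shows a dog.',
    'images': ['/xxx/x.jpg'],  # placeholder path
}
seq_cls_sample = {
    'messages': [{'role': 'user', 'content': '<image>What is in the image?'}],
    'label': 1,
    'images': ['/xxx/x.jpg'],  # placeholder path
}
```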

swift/llm/model/model/qwen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ def get_model_tokenizer_qwen2_audio(*args, **kwargs):
650650

651651

652652
def get_model_tokenizer_ovis(*args, **kwargs):
653+
kwargs['attn_impl_keys'] = ['llm_attn_implementation']
653654
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
654655
model.visual_tokenizer.to(model.dtype)
655656
model.vte.to(model.dtype)
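The Ovis config keeps its inner LLM's attention setting under `llm_attn_implementation` rather than the usual `_attn_implementation`, which is what this override accounts for. A minimal sketch of the intended effect, assuming a config object that exposes that key (`PretrainedConfig` stores extra constructor kwargs as attributes):

```python
# Hedged sketch, not part of the commit: exercising the new attn_impl_keys
# routing against a stand-in for an Ovis-style config.
from transformers import PretrainedConfig

from swift.llm.model.utils import AttnImpl  # module path per this commit's diff

config = PretrainedConfig(llm_attn_implementation='sdpa')
AttnImpl.update_attn_impl(config, 'flash_attn', attn_impl_keys=['llm_attn_implementation'])
print(config.llm_attn_implementation)  # expected: 'flash_attention_2'
```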

swift/llm/model/register.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -237,7 +237,7 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
     model_config = kwargs.get('model_config')
     if model_config is None:
         model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
-    AttnImpl.update_attn_impl(model_config, kwargs.get('attn_impl'))
+    AttnImpl.update_attn_impl(model_config, kwargs.get('attn_impl'), kwargs.get('attn_impl_keys'))
     kwargs['model_config'] = model_config
     return get_model_tokenizer_from_local(model_dir, model_info, model_kwargs, load_model, **kwargs)
```
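Since `attn_impl_keys` rides along in `**kwargs`, any `get_model_tokenizer_*` integration can now redirect the user's `attn_impl` to a nonstandard config key the same way the Ovis loader above does. A hedged sketch with a hypothetical key name:

```python
# Hypothetical integration (not in the commit): 'my_llm_attn_key' is
# illustrative only. Injecting 'attn_impl_keys' before delegating makes
# update_attn_impl write the resolved backend to that key.
from swift.llm.model.register import get_model_tokenizer_with_flash_attn


def get_model_tokenizer_custom(*args, **kwargs):
    kwargs['attn_impl_keys'] = ['my_llm_attn_key']
    return get_model_tokenizer_with_flash_attn(*args, **kwargs)
```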

swift/llm/model/utils.py

Lines changed: 23 additions & 15 deletions

```diff
@@ -6,10 +6,8 @@
 from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union
 
 import torch
-import transformers
 from accelerate.utils import find_device
 from modelscope.hub.utils.utils import get_cache_dir
-from packaging import version
 from transformers import PretrainedConfig
 
 from swift.hub import get_hub
@@ -26,25 +24,31 @@ class AttnImpl:
     sdpa = 'sdpa'
     eager = 'eager'
 
+    attn_impl_keys = ['_attn_implementation', 'attn_implementation', 'llm_attn_implementation']
+    use_flash_attn_keys = ['_flash_attn_2_enabled', 'use_flash_attn', '_use_flash_attention_2']
+
     @staticmethod
     def to_use_flash_attn(attn_impl: Optional[str], auto_value: _T = None) -> Union[bool, _T]:
         if attn_impl is None:
             return auto_value
         return attn_impl == AttnImpl.flash_attn
 
     @staticmethod
-    def update_attn_impl(config: PretrainedConfig, attn_impl: Optional[str], auto_value: _T = None) -> None:
-
-        use_flash_attn = AttnImpl.to_use_flash_attn(attn_impl, auto_value)
-        if use_flash_attn is None:
+    def update_attn_impl(config: PretrainedConfig,
+                         attn_impl: Optional[str],
+                         attn_impl_keys: Optional[List[str]] = None) -> None:
+        if attn_impl is None:
             return
-        from swift.llm import HfConfigFactory
-        if version.parse(transformers.__version__) >= version.parse('4.36'):
-            if use_flash_attn:
-                attn_impl = 'flash_attention_2'
-            HfConfigFactory.set_config_attr(config, '_attn_implementation', attn_impl)
-        else:
-            HfConfigFactory.set_config_attr(config, '_flash_attn_2_enabled', use_flash_attn)
+        use_flash_attn = AttnImpl.to_use_flash_attn(attn_impl)
+        if use_flash_attn:
+            attn_impl = 'flash_attention_2'
+        if isinstance(attn_impl_keys, str):
+            attn_impl_keys = [attn_impl_keys]
+        attn_impl_keys = attn_impl_keys or AttnImpl.attn_impl_keys
+        for key in attn_impl_keys:
+            HfConfigFactory.set_config_attr(config, key, attn_impl, ensure_set=False)
+        for key in AttnImpl.use_flash_attn_keys:
+            HfConfigFactory.set_config_attr(config, key, use_flash_attn, ensure_set=False)
 
 
 @dataclass
@@ -109,16 +113,20 @@ def get_config_attr(config: Union[PretrainedConfig, Dict[str, Any]], attr_name:
         return attrs[0][1]
 
     @staticmethod
-    def set_config_attr(config: Union[PretrainedConfig, Dict[str, Any]], attr_name: str, value: Any) -> None:
+    def set_config_attr(config: Union[PretrainedConfig, Dict[str, Any]],
+                        attr_name: str,
+                        value: Any,
+                        ensure_set: bool = True) -> int:
         """Set all the attr_name attributes to value."""
         attrs = HfConfigFactory._get_config_attrs(config, attr_name)
-        if len(attrs) == 0:
+        if ensure_set and len(attrs) == 0:
             attrs.append((config, None))
         for config, _ in attrs:
             if isinstance(config, dict):
                 config[attr_name] = value
             else:
                 setattr(config, attr_name, value)
+        return len(attrs)
 
     @staticmethod
     def set_model_config_attr(model, attr_name: str, value: Any) -> None:
```
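The net behavior: `update_attn_impl` now only rewrites attention keys that a model's config actually defines (via `ensure_set=False`), and `set_config_attr` reports how many matching attributes it touched. A hedged sketch using a plain dict config, assuming `_get_config_attrs` matches dict keys as the `isinstance(config, dict)` branch implies:

```python
# Hedged sketch, not part of the commit.
from swift.llm import HfConfigFactory  # import path shown in the removed line above

config = {'_attn_implementation': 'eager'}

# Existing key: rewritten, returns the number of matches found.
n = HfConfigFactory.set_config_attr(config, '_attn_implementation', 'flash_attention_2')
print(n, config)  # expected: 1 {'_attn_implementation': 'flash_attention_2'}

# Missing key with ensure_set=False: left unset, returns 0.
n = HfConfigFactory.set_config_attr(config, 'use_flash_attn', True, ensure_set=False)
print(n, config)  # expected: 0, config unchanged
```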

tests/models/test_flash_attn.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -0,0 +1,8 @@
+from swift.llm import get_model_tokenizer
+
+if __name__ == '__main__':
+    # model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', attn_impl='flash_attn')
+    # model, tokenizer = get_model_tokenizer('AIDC-AI/Ovis2-2B', attn_impl='flash_attn')
+    # model, tokenizer = get_model_tokenizer('OpenGVLab/InternVL2-2B', attn_impl='flash_attn')
+    model, tokenizer = get_model_tokenizer('Shanghai_AI_Laboratory/internlm3-8b-instruct', attn_impl='flash_attn')
+    print(model)
```

tests/test_align/test_vllm_vlm.py

Lines changed: 53 additions & 1 deletion

```diff
@@ -31,6 +31,20 @@ def _infer_image(model, use_chat_template: bool = True, max_model_len=8192, syst
     return resp_list[0].choices[0].message.content
 
 
+def _infer_video(model, use_chat_template: bool = True, max_model_len=8192, system=None):
+    engine = VllmEngine(model, max_model_len=max_model_len, limit_mm_per_prompt={'image': 16, 'video': 2})
+    if not use_chat_template:
+        engine.default_template.use_chat_template = False
+    videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
+    messages = []
+    if system is not None:
+        messages += [{'role': 'system', 'content': system}]
+    messages.append({'role': 'user', 'content': 'describe the video.'})
+    resp_list = engine.infer([InferRequest(messages=messages, videos=videos)],
+                             RequestConfig(temperature=0, max_tokens=64, repetition_penalty=1.))
+    return resp_list[0].choices[0].message.content
+
+
 def test_qwen2_audio():
     response = _infer_audio('Qwen/Qwen2-Audio-7B-Instruct')
     assert response == "The audio is a man speaking in Mandarin saying '今天天气真好呀'."
@@ -68,10 +82,48 @@ def test_internvl2():
                         'and it appears to be looking directly at the camera. The fur is soft and fluffy, with a mix')
 
 
+def test_minicpmv_2_5():
+    response = _infer_image('OpenBMB/MiniCPM-Llama3-V-2_5', max_model_len=4096)
+    assert response == (
+        "The image is a digital painting of a kitten that captures the essence of a young feline's innocence "
+        "and curiosity. The kitten's fur is rendered with a mix of gray, white, and black stripes, "
+        'giving it a realistic and adorable appearance. Its large, expressive eyes are a striking blue, '
+        "which draws the viewer's")
+
+
+def test_minicpmv_2_6():
+    response = _infer_image('OpenBMB/MiniCPM-V-2_6', max_model_len=4096)
+    assert response == (
+        'The image features a close-up of a kitten with striking blue eyes and a mix of '
+        "white and dark fur, possibly gray or black. The kitten's gaze is directed forward, giving it an "
+        "expressive and captivating look. The background is blurred, drawing focus to the kitten's face. "
+        "The overall composition emphasizes the kitten's features")
+
+
+def test_minicpmo_2_6_video():
+    response = _infer_video('OpenBMB/MiniCPM-o-2_6')
+    assert response == ('The video features a young child sitting on a bed, deeply engaged in reading a book. '
+                        'The child, dressed in a light blue sleeveless top and pink pants, is surrounded by a '
+                        'cozy and homely environment. The bed is adorned with a patterned blanket, and a white cloth '
+                        'is casually draped over the side.')
+
+
+def test_qwen2_5_vl_video():
+    response = _infer_video('Qwen/Qwen2.5-VL-3B-Instruct')
+    assert response == ('A baby wearing sunglasses is sitting on a bed and reading a book. '
+                        'The baby is holding the book with both hands and is looking at the pages. '
+                        'The baby is wearing a light blue shirt and pink pants. The baby is sitting '
+                        'on a white blanket. The baby is looking at the book and is smiling. The baby')
+
+
 if __name__ == '__main__':
     from swift.llm import VllmEngine, InferRequest, RequestConfig
     # test_qwen2_vl()
     # test_qwen2_5_vl()
     # test_deepseek_vl_v2()
     # test_internvl2()
-    test_qwen2_audio()
+    # test_qwen2_audio()
+    # test_minicpmv_2_5()
+    # test_minicpmv_2_6()
+    test_minicpmo_2_6_video()
+    # test_qwen2_5_vl_video()
```
